From 08b7d99a75ad70fdc5c1d9e5debc675fce5e83a9 Mon Sep 17 00:00:00 2001 From: nihui Date: Wed, 8 May 2024 19:25:51 +0800 Subject: [PATCH] rnn/lstm/gru dynamic quantization (#5435) --- .ci/test-coverage.yml | 1 + cmake/ncnn_add_layer.cmake | 20 +- .../quantized-int8-inference.md | 6 + src/CMakeLists.txt | 8 +- src/layer/arm/gru_arm.cpp | 468 ++- src/layer/arm/gru_arm.h | 14 +- src/layer/arm/gru_arm_asimddp.cpp | 35 + src/layer/arm/gru_arm_asimdhp.cpp | 823 ++--- src/layer/arm/gru_arm_vfpv4.cpp | 30 + src/layer/arm/gru_int8.h | 1405 ++++++++ src/layer/arm/lstm_arm.cpp | 443 ++- src/layer/arm/lstm_arm.h | 14 +- src/layer/arm/lstm_arm_asimddp.cpp | 35 + src/layer/arm/lstm_arm_asimdhp.cpp | 765 ++-- src/layer/arm/lstm_arm_vfpv4.cpp | 30 + src/layer/arm/lstm_int8.h | 844 +++++ src/layer/arm/neon_mathfun.h | 4 +- src/layer/arm/rnn_arm.cpp | 436 ++- src/layer/arm/rnn_arm.h | 14 +- src/layer/arm/rnn_arm_asimddp.cpp | 35 + src/layer/arm/rnn_arm_asimdhp.cpp | 447 +-- src/layer/arm/rnn_arm_vfpv4.cpp | 30 + src/layer/arm/rnn_int8.h | 769 ++++ src/layer/gru.cpp | 296 +- src/layer/gru.h | 7 + src/layer/lstm.cpp | 338 +- src/layer/lstm.h | 7 + src/layer/riscv/gru_riscv.cpp | 22 + src/layer/rnn.cpp | 235 +- src/layer/rnn.h | 7 + src/layer/x86/lstm_int8.h | 3163 +++++++++++++++++ src/layer/x86/lstm_x86.cpp | 272 +- src/layer/x86/lstm_x86.h | 14 + src/layer/x86/lstm_x86_avx2.cpp | 35 + src/layer/x86/lstm_x86_avx512vnni.cpp | 35 + src/layer/x86/lstm_x86_avxvnni.cpp | 35 + src/layer/x86/lstm_x86_xop.cpp | 30 + src/layer/x86/x86_usability.h | 3 + tests/test_gru.cpp | 431 ++- tests/test_lstm.cpp | 502 ++- tests/test_rnn.cpp | 435 ++- tests/testutil.cpp | 24 + tools/modelwriter.h | 30 + tools/quantize/ncnn2int8.cpp | 258 +- 44 files changed, 11129 insertions(+), 1726 deletions(-) create mode 100644 src/layer/arm/gru_arm_asimddp.cpp create mode 100644 src/layer/arm/gru_arm_vfpv4.cpp create mode 100644 src/layer/arm/gru_int8.h create mode 100644 src/layer/arm/lstm_arm_asimddp.cpp create mode 100644 src/layer/arm/lstm_arm_vfpv4.cpp create mode 100644 src/layer/arm/lstm_int8.h create mode 100644 src/layer/arm/rnn_arm_asimddp.cpp create mode 100644 src/layer/arm/rnn_arm_vfpv4.cpp create mode 100644 src/layer/arm/rnn_int8.h create mode 100644 src/layer/x86/lstm_int8.h create mode 100644 src/layer/x86/lstm_x86_avx2.cpp create mode 100644 src/layer/x86/lstm_x86_avx512vnni.cpp create mode 100644 src/layer/x86/lstm_x86_avxvnni.cpp create mode 100644 src/layer/x86/lstm_x86_xop.cpp diff --git a/.ci/test-coverage.yml b/.ci/test-coverage.yml index 0b016be1574..f46bf6e3621 100644 --- a/.ci/test-coverage.yml +++ b/.ci/test-coverage.yml @@ -187,6 +187,7 @@ jobs: - { SSE2: 'ON', AVX: 'OFF', XOP: 'OFF', F16C: 'OFF', FMA: 'OFF', AVX2: 'OFF', AVX512: 'OFF', AVX512VNNI: 'OFF', AVXVNNI: 'OFF', AVX512BF16: 'OFF', AVX512FP16: 'OFF'} - { SSE2: 'ON', AVX: 'ON', XOP: 'OFF', F16C: 'OFF', FMA: 'OFF', AVX2: 'OFF', AVX512: 'OFF', AVX512VNNI: 'OFF', AVXVNNI: 'OFF', AVX512BF16: 'OFF', AVX512FP16: 'OFF'} - { SSE2: 'ON', AVX: 'ON', XOP: 'OFF', F16C: 'ON', FMA: 'ON', AVX2: 'ON', AVX512: 'OFF', AVX512VNNI: 'OFF', AVXVNNI: 'OFF', AVX512BF16: 'OFF', AVX512FP16: 'OFF'} + - { SSE2: 'ON', AVX: 'ON', XOP: 'OFF', F16C: 'ON', FMA: 'ON', AVX2: 'ON', AVX512: 'ON', AVX512VNNI: 'OFF', AVXVNNI: 'OFF', AVX512BF16: 'OFF', AVX512FP16: 'OFF'} - { SSE2: 'ON', AVX: 'ON', XOP: 'OFF', F16C: 'ON', FMA: 'ON', AVX2: 'ON', AVX512: 'ON', AVX512VNNI: 'ON', AVXVNNI: 'OFF', AVX512BF16: 'OFF', AVX512FP16: 'OFF'} runs-on: diff --git a/cmake/ncnn_add_layer.cmake 
b/cmake/ncnn_add_layer.cmake index a41c52e4ed2..01122487528 100644 --- a/cmake/ncnn_add_layer.cmake +++ b/cmake/ncnn_add_layer.cmake @@ -136,34 +136,34 @@ macro(ncnn_add_layer class) if(NCNN_TARGET_ARCH STREQUAL "x86") if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) if(NCNN_RUNTIME_CPU AND NCNN_AVX512) - ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__") + ncnn_add_arch_opt_layer(${class} avx512 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() if(NCNN_RUNTIME_CPU AND NCNN_FMA) - ncnn_add_arch_opt_layer(${class} fma "/arch:AVX /D__SSE4_1__ /D__FMA__ /D__F16C__") + ncnn_add_arch_opt_layer(${class} fma "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() if(NCNN_RUNTIME_CPU AND NCNN_AVX) - ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSE4_1__") + ncnn_add_arch_opt_layer(${class} avx "/arch:AVX /D__SSSE3__ /D__SSE4_1__") endif() if(NCNN_AVX512VNNI) - ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__") + ncnn_add_arch_opt_source(${class} avx512vnni "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512VNNI__") endif() if(NCNN_AVX512BF16) - ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__") + ncnn_add_arch_opt_source(${class} avx512bf16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512BF16__") endif() if(NCNN_AVX512FP16) - ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__") + ncnn_add_arch_opt_source(${class} avx512fp16 "/arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVX512FP16__") endif() if(NCNN_AVXVNNI) - ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") + ncnn_add_arch_opt_source(${class} avxvnni "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__ /D__AVXVNNI__") endif() if(NCNN_AVX2) - ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSE4_1__ /D__FMA__ /D__F16C__") + ncnn_add_arch_opt_source(${class} avx2 "/arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__") endif() if(NCNN_XOP) - ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__SSE4_1__ /D__XOP__") + ncnn_add_arch_opt_source(${class} xop "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__XOP__") endif() if(NCNN_F16C) - ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__SSE4_1__ /D__F16C__") + ncnn_add_arch_opt_source(${class} f16c "/arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__F16C__") endif() else() if(NCNN_RUNTIME_CPU AND NCNN_AVX512) diff --git a/docs/how-to-use-and-FAQ/quantized-int8-inference.md b/docs/how-to-use-and-FAQ/quantized-int8-inference.md index cf8e05c2095..9b51f7b68ea 100644 --- a/docs/how-to-use-and-FAQ/quantized-int8-inference.md +++ b/docs/how-to-use-and-FAQ/quantized-int8-inference.md @@ -48,6 +48,12 @@ If your model has multiple input nodes, you can use multiple list files and othe ./ncnn2int8 mobilenet-opt.param mobilenet-opt.bin mobilenet-int8.param mobilenet-int8.bin mobilenet.table ``` +If you don’t need static quantization, ncnn supports RNN/LSTM/GRU dynamic quantization. In this case, you can omit the table file. 
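+With dynamic quantization the RNN/LSTM/GRU weights are stored as int8 offline, and the input activations are quantized on the fly at inference time: for every time step the absolute maximum of the input row is measured and the row is rescaled by 127/absmax (this is what the layer's `dynamic_quantize` helper in this patch does). The snippet below is only a simplified, hypothetical sketch of that per-row step, not the actual ncnn code:
+
+```cpp
+#include <algorithm>
+#include <cmath>
+
+// hypothetical standalone sketch of per-timestep dynamic quantization:
+// scale = 127 / absmax, descale = absmax / 127, as in the arm/x86 int8 paths of this patch
+static void quantize_row(const float* x, signed char* xq, int size, float& descale)
+{
+    float absmax = 0.f;
+    for (int i = 0; i < size; i++)
+        absmax = std::max(absmax, (float)std::fabs(x[i]));
+
+    const float scale = absmax == 0.f ? 1.f : 127.f / absmax; // guard an all-zero row
+    descale = absmax == 0.f ? 1.f : absmax / 127.f;
+
+    for (int i = 0; i < size; i++)
+        xq[i] = (signed char)std::round(x[i] * scale); // int8 input fed to the int8 kernels
+}
+```
+
+With the table file omitted, the conversion command simply drops the last argument: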
+ +```shell +./ncnn2int8 rnn-model.param rnn-model.bin rnn-model-int8.param rnn-model-int8.bin +``` + ## use ncnn int8 inference the ncnn library would use int8 inference automatically, nothing changed in your code diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 91041813abd..04ebeb06d14 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -404,7 +404,7 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") if(NOT NCNN_RUNTIME_CPU AND NCNN_AVX512) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) - target_compile_options(ncnn PRIVATE /arch:AVX512 /D__SSE4_1__ /D__FMA__ /D__F16C__) + target_compile_options(ncnn PRIVATE /arch:AVX512 /D__SSSE3__ /D__SSE4_1__ /D__FMA__ /D__F16C__) if(NCNN_AVX512VNNI) target_compile_options(ncnn PRIVATE /D__AVX512VNNI__) endif() @@ -423,9 +423,9 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") elseif(NOT NCNN_RUNTIME_CPU AND NCNN_FMA) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) if(NCNN_AVX2) - target_compile_options(ncnn PRIVATE /arch:AVX2 /D__SSE4_1__ /D__FMA__) + target_compile_options(ncnn PRIVATE /arch:AVX2 /D__SSSE3__ /D__SSE4_1__ /D__FMA__) else() - target_compile_options(ncnn PRIVATE /arch:AVX /D__SSE4_1__ /D__FMA__) + target_compile_options(ncnn PRIVATE /arch:AVX /D__SSSE3__ /D__SSE4_1__ /D__FMA__) endif() if(NCNN_AVXVNNI) target_compile_options(ncnn PRIVATE /D__AVXVNNI__) @@ -452,7 +452,7 @@ if(NCNN_TARGET_ARCH STREQUAL "x86") endif() elseif(NOT NCNN_RUNTIME_CPU AND NCNN_AVX) if(CMAKE_CXX_COMPILER_ID MATCHES "MSVC" OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang" AND CMAKE_CXX_SIMULATE_ID MATCHES "MSVC" AND CMAKE_CXX_COMPILER_FRONTEND_VARIANT MATCHES "MSVC")) - target_compile_options(ncnn PRIVATE /arch:AVX /D__SSE4_1__) + target_compile_options(ncnn PRIVATE /arch:AVX /D__SSSE3__ /D__SSE4_1__) if(NCNN_XOP) target_compile_options(ncnn PRIVATE /D__XOP__) endif() diff --git a/src/layer/arm/gru_arm.cpp b/src/layer/arm/gru_arm.cpp index 80f8c80ad3c..a42241947c1 100644 --- a/src/layer/arm/gru_arm.cpp +++ b/src/layer/arm/gru_arm.cpp @@ -25,6 +25,10 @@ namespace ncnn { +#if NCNN_INT8 +#include "gru_int8.h" +#endif + GRU_arm::GRU_arm() { #if __ARM_NEON @@ -40,6 +44,13 @@ GRU_arm::GRU_arm() int GRU_arm::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + return create_pipeline_int8(opt); + } +#endif + #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage) { @@ -55,8 +66,8 @@ int GRU_arm::create_pipeline(const Option& opt) #endif // pack RUN - int num_directions = direction == 2 ? 2 : 1; - int size = weight_data_size / num_directions / num_output / 3; + const int num_directions = direction == 2 ? 
2 : 1; + const int size = weight_data_size / num_directions / num_output / 3; #if __ARM_NEON weight_xc_data_packed.create(size * 12, num_output / 4 + num_output % 4, num_directions); @@ -90,22 +101,10 @@ int GRU_arm::create_pipeline(const Option& opt) #if __ARM_NEON for (; q + 3 < num_output; q += 4) { - bias_c_RUBNWN[0] = bias_c_R[q]; - bias_c_RUBNWN[1] = bias_c_R[q + 1]; - bias_c_RUBNWN[2] = bias_c_R[q + 2]; - bias_c_RUBNWN[3] = bias_c_R[q + 3]; - bias_c_RUBNWN[4] = bias_c_U[q]; - bias_c_RUBNWN[5] = bias_c_U[q + 1]; - bias_c_RUBNWN[6] = bias_c_U[q + 2]; - bias_c_RUBNWN[7] = bias_c_U[q + 3]; - bias_c_RUBNWN[8] = bias_c_BN[q]; - bias_c_RUBNWN[9] = bias_c_BN[q + 1]; - bias_c_RUBNWN[10] = bias_c_BN[q + 2]; - bias_c_RUBNWN[11] = bias_c_BN[q + 3]; - bias_c_RUBNWN[12] = bias_c_WN[q]; - bias_c_RUBNWN[13] = bias_c_WN[q + 1]; - bias_c_RUBNWN[14] = bias_c_WN[q + 2]; - bias_c_RUBNWN[15] = bias_c_WN[q + 3]; + vst1q_f32(bias_c_RUBNWN, vld1q_f32(bias_c_R + q)); + vst1q_f32(bias_c_RUBNWN + 4, vld1q_f32(bias_c_U + q)); + vst1q_f32(bias_c_RUBNWN + 8, vld1q_f32(bias_c_BN + q)); + vst1q_f32(bias_c_RUBNWN + 12, vld1q_f32(bias_c_WN + q)); bias_c_RUBNWN += 16; @@ -637,16 +636,18 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we int GRU_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blob, top_blob, opt); + } +#endif + int elembits = bottom_blob.elembits(); #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - { - if (opt.use_fp16_arithmetic) - return forward_fp16sa(bottom_blob, top_blob, opt); - else - return forward_fp16s(bottom_blob, top_blob, opt); - } + return forward_fp16s(bottom_blob, top_blob, opt); #endif #if NCNN_BF16 @@ -686,15 +687,19 @@ int GRU_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c if (top_blob_reverse.empty()) return -100; - int ret0 = gru(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; + { + int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.0f); - int ret1 = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; + { + int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -713,17 +718,19 @@ int GRU_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c int GRU_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blobs, top_blobs, opt); + } +#endif + const Mat& bottom_blob = bottom_blobs[0]; int elembits = bottom_blob.elembits(); #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - { - if (opt.use_fp16_arithmetic) - return forward_fp16sa(bottom_blobs, top_blobs, opt); - else - return forward_fp16s(bottom_blobs, top_blobs, opt); - } + return forward_fp16s(bottom_blobs, top_blobs, opt); #endif #if NCNN_BF16 @@ -772,14 +779,18 @@ int 
GRU_arm::forward(const std::vector& bottom_blobs, std::vector& top return -100; Mat hidden0 = hidden.row_range(0, 1); - int ret0 = gru(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); - if (ret0 != 0) - return ret0; + { + int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); + if (ret != 0) + return ret; + } Mat hidden1 = hidden.row_range(1, 1); - int ret1 = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); - if (ret1 != 0) - return ret1; + { + int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -1215,22 +1226,10 @@ int GRU_arm::create_pipeline_bf16s(const Option& opt) #if __ARM_NEON for (; q + 3 < num_output; q += 4) { - bias_c_RUBNWN[0] = float32_to_bfloat16(bias_c_R[q]); - bias_c_RUBNWN[1] = float32_to_bfloat16(bias_c_R[q + 1]); - bias_c_RUBNWN[2] = float32_to_bfloat16(bias_c_R[q + 2]); - bias_c_RUBNWN[3] = float32_to_bfloat16(bias_c_R[q + 3]); - bias_c_RUBNWN[4] = float32_to_bfloat16(bias_c_U[q]); - bias_c_RUBNWN[5] = float32_to_bfloat16(bias_c_U[q + 1]); - bias_c_RUBNWN[6] = float32_to_bfloat16(bias_c_U[q + 2]); - bias_c_RUBNWN[7] = float32_to_bfloat16(bias_c_U[q + 3]); - bias_c_RUBNWN[8] = float32_to_bfloat16(bias_c_BN[q]); - bias_c_RUBNWN[9] = float32_to_bfloat16(bias_c_BN[q + 1]); - bias_c_RUBNWN[10] = float32_to_bfloat16(bias_c_BN[q + 2]); - bias_c_RUBNWN[11] = float32_to_bfloat16(bias_c_BN[q + 3]); - bias_c_RUBNWN[12] = float32_to_bfloat16(bias_c_WN[q]); - bias_c_RUBNWN[13] = float32_to_bfloat16(bias_c_WN[q + 1]); - bias_c_RUBNWN[14] = float32_to_bfloat16(bias_c_WN[q + 2]); - bias_c_RUBNWN[15] = float32_to_bfloat16(bias_c_WN[q + 3]); + vst1_u16(bias_c_RUBNWN, float2bfloat(vld1q_f32(bias_c_R + q))); + vst1_u16(bias_c_RUBNWN + 4, float2bfloat(vld1q_f32(bias_c_U + q))); + vst1_u16(bias_c_RUBNWN + 8, float2bfloat(vld1q_f32(bias_c_BN + q))); + vst1_u16(bias_c_RUBNWN + 12, float2bfloat(vld1q_f32(bias_c_WN + q))); bias_c_RUBNWN += 16; @@ -1419,15 +1418,19 @@ int GRU_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = gru_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; + { + int ret = gru_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.f); - int ret1 = gru_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; + { + int ret = gru_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -1490,14 +1493,18 @@ int GRU_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(t); + + float absmax = 0.f; + for (int i = 
0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(float16_to_float32(x[i]))); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + bottom_blob_int8_descales[t] = absmax / 127.f; + } + } + if (elemtype == 4) + { + // bf16 + for (int t = 0; t < T; t++) + { + const unsigned short* x = bottom_blob.row(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(bfloat16_to_float32(x[i]))); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + bottom_blob_int8_descales[t] = absmax / 127.f; + } + } + + quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt); +} + +int GRU_arm::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int elemtype = 1; // fp32 + { + int elembits = bottom_blob.elembits(); + + // clang-format off + // *INDENT-OFF* + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + { + elemtype = 2; // fp16 + } + else +#endif +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + { + elemtype = 4; // bf16 + } + else +#endif + { + // fp32 + } + + // *INDENT-ON* + // clang-format on + } + + int T = bottom_blob.h; + size_t elemsize = bottom_blob.elemsize; + + int num_directions = direction == 2 ? 2 : 1; + + // initial hidden state + Mat hidden(num_output, 4u, opt.workspace_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + + top_blob.create(num_output * num_directions, T, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // dynamic quantize bottom_blob + Mat bottom_blob_int8; + Mat bottom_blob_int8_descales; + { + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + dynamic_quantize(bottom_blob, elemtype, bottom_blob_int8, bottom_blob_int8_descales, opt_quant); + } + + // Uni directional + if (direction == 0 || direction == 1) + { + gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, direction, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), hidden, opt); + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + { + gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_forward, elemtype, 0, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), hidden, opt); + } + + hidden.fill(0.f); + + { + gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_reverse, elemtype, 1, weight_data_tm.channel(1), weight_data_tm_int8_descales.channel(1), bias_c_data_packed.channel(1), hidden, opt); + } + + // concat w + for (int i = 0; i < T; i++) + { + const unsigned char* pf = top_blob_forward.row(i); + const unsigned char* pr = top_blob_reverse.row(i); + unsigned char* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * elemsize); + memcpy(ptr + num_output * elemsize, pr, num_output * elemsize); + } + } + + return 0; +} + +int GRU_arm::forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + + int elemtype = 1; // fp32 + { + int elembits = bottom_blob.elembits(); + + // clang-format off + // *INDENT-OFF* + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + { + elemtype = 
2; // fp16 + } + else +#endif +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + { + elemtype = 4; // bf16 + } + else +#endif + { + // fp32 + } + + // *INDENT-ON* + // clang-format on + } + + int T = bottom_blob.h; + size_t elemsize = bottom_blob.elemsize; + int num_directions = direction == 2 ? 2 : 1; + + Mat hidden; + Allocator* hidden_allocator = top_blobs.size() == 2 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 2) + { + if (elemtype == 1) + { + hidden = bottom_blobs[1].clone(hidden_allocator); + } + if (elemtype == 2) + { + Option opt_cast = opt; + opt_cast.blob_allocator = hidden_allocator; + cast_float16_to_float32(bottom_blobs[1], hidden, opt_cast); + } + if (elemtype == 4) + { + Option opt_cast = opt; + opt_cast.blob_allocator = hidden_allocator; + cast_bfloat16_to_float32(bottom_blobs[1], hidden, opt_cast); + } + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // dynamic quantize bottom_blob + Mat bottom_blob_int8; + Mat bottom_blob_int8_descales; + { + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + dynamic_quantize(bottom_blob, elemtype, bottom_blob_int8, bottom_blob_int8_descales, opt_quant); + } + + // Uni directional + if (direction == 0 || direction == 1) + { + gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, direction, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), hidden, opt); + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + { + gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_forward, elemtype, 0, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), hidden0, opt); + } + + Mat hidden1 = hidden.row_range(1, 1); + { + gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_reverse, elemtype, 1, weight_data_tm.channel(1), weight_data_tm_int8_descales.channel(1), bias_c_data_packed.channel(1), hidden1, opt); + } + + // concat w + for (int i = 0; i < T; i++) + { + const unsigned char* pf = top_blob_forward.row(i); + const unsigned char* pr = top_blob_reverse.row(i); + unsigned char* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * elemsize); + memcpy(ptr + num_output * elemsize, pr, num_output * elemsize); + } + } + + if (top_blobs.size() == 2) + { + if (elemtype == 1) + { + top_blobs[1] = hidden; + } + if (elemtype == 2) + { + cast_float32_to_float16(hidden, top_blobs[1], opt); + } + if (elemtype == 4) + { + cast_float32_to_bfloat16(hidden, top_blobs[1], opt); + } + } + + return 0; +} +#endif // NCNN_INT8 + } // namespace ncnn diff --git a/src/layer/arm/gru_arm.h b/src/layer/arm/gru_arm.h index 6eae1656b01..aba1608df90 100644 --- a/src/layer/arm/gru_arm.h +++ b/src/layer/arm/gru_arm.h @@ -33,19 +33,29 @@ class GRU_arm : public GRU int create_pipeline_fp16s(const Option& opt); int forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; int forward_fp16s(const std::vector& 
bottom_blobs, std::vector& top_blobs, const Option& opt) const; - int forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; - int forward_fp16sa(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif #if NCNN_BF16 int create_pipeline_bf16s(const Option& opt); int forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; int forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif +#if NCNN_INT8 + int create_pipeline_int8(const Option& opt); + void dynamic_quantize(const Mat& bottom_blob, int elemtype, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const; + int forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif public: Mat weight_xc_data_packed; Mat bias_c_data_packed; Mat weight_hc_data_packed; + + Mat weight_data_tm; + +#if NCNN_INT8 + Mat weight_data_tm_int8_descales; +#endif }; } // namespace ncnn diff --git a/src/layer/arm/gru_arm_asimddp.cpp b/src/layer/arm/gru_arm_asimddp.cpp new file mode 100644 index 00000000000..3de7ed84ead --- /dev/null +++ b/src/layer/arm/gru_arm_asimddp.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
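+
+// This file is a thin runtime-dispatch wrapper: it re-includes gru_int8.h so the shared
+// int8 kernels are compiled again in this translation unit, which is presumably built with
+// the Armv8.2 dot-product (asimddp) architecture options set up in ncnn_add_layer.cmake.
+// The *_asimddp entry points below just forward to those kernels and are intended to be
+// selected at runtime when the CPU supports dot-product instructions.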
+ +#include "cpu.h" +#include "mat.h" +#include "layer.h" +#include "arm_activation.h" +#include "arm_usability.h" + +namespace ncnn { + +#include "gru_int8.h" + +void gru_transform_weight_int8_asimddp(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, const Option& opt) +{ + gru_transform_weight_int8(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, opt); +} + +void gru_int8_asimddp(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, Mat& hidden_state, const Option& opt) +{ + gru_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, hidden_state, opt); +} + +} // namespace ncnn diff --git a/src/layer/arm/gru_arm_asimdhp.cpp b/src/layer/arm/gru_arm_asimdhp.cpp index f3d38305a2e..3a3d92d5d57 100644 --- a/src/layer/arm/gru_arm_asimdhp.cpp +++ b/src/layer/arm/gru_arm_asimdhp.cpp @@ -23,7 +23,7 @@ namespace ncnn { #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) +static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) { int size = bottom_blob.w; int T = bottom_blob.h; @@ -55,177 +55,253 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M const __fp16* weight_xc_RUN = weight_xc.row(q / 4); const __fp16* weight_hc_RUN = weight_hc.row(q / 4); - float32x4_t _gru_R = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN)); - float32x4_t _gru_U = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 4)); - float32x4_t _sum1 = vdupq_n_f32(0.f); - float32x4_t _sum2 = vdupq_n_f32(0.f); - float32x4_t _sum3 = vdupq_n_f32(0.f); - float32x4_t _sum4 = vdupq_n_f32(0.f); - float32x4_t _sum5 = vdupq_n_f32(0.f); - float32x4_t _sum6 = vdupq_n_f32(0.f); + float16x8_t _RU = vld1q_f16(bias_c_RUBNWN); + float16x8_t _sum1 = vdupq_n_f16((__fp16)0.f); + float16x8_t _sum2 = vdupq_n_f16((__fp16)0.f); + float16x8_t _sum3 = vdupq_n_f16((__fp16)0.f); int i = 0; for (; i + 3 < size; i += 4) { - float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); - float32x4_t _weight_xc_R = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); - float32x4_t _weight_xc_U = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); - float32x4_t _weight_xc_R_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 8)); - float32x4_t _weight_xc_U_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 12)); - float32x4_t _weight_xc_R_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 16)); - float32x4_t _weight_xc_U_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 20)); - float32x4_t _weight_xc_R_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 24)); - float32x4_t _weight_xc_U_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 28)); - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); - _sum4 = 
vfmaq_laneq_f32(_sum4, _weight_xc_U_2, _xi, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_xc_R_3, _xi, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_xc_U_3, _xi, 3); +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v4.4h}, [%0], #8 \n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" + "fmla %2.8h, v0.8h, v4.h[0] \n" + "fmla %3.8h, v1.8h, v4.h[1] \n" + "fmla %4.8h, v2.8h, v4.h[2] \n" + "fmla %5.8h, v3.8h, v4.h[3] \n" + : "=r"(x), + "=r"(weight_xc_RUN), + "=w"(_RU), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3) + : "0"(x), + "1"(weight_xc_RUN), + "2"(_RU), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3) + : "memory", "v0", "v1", "v2", "v3", "v4"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _x = vld1_f16(x); + float16x8_t _w0 = vld1q_f16(weight_xc_RUN); + float16x8_t _w1 = vld1q_f16(weight_xc_RUN + 8); + float16x8_t _w2 = vld1q_f16(weight_xc_RUN + 16); + float16x8_t _w3 = vld1q_f16(weight_xc_RUN + 24); + _RU = vfmaq_lane_f16(_RU, _w0, _x, 0); + _sum1 = vfmaq_lane_f16(_sum1, _w1, _x, 1); + _sum2 = vfmaq_lane_f16(_sum2, _w2, _x, 2); + _sum3 = vfmaq_lane_f16(_sum3, _w3, _x, 3); + x += 4; weight_xc_RUN += 32; +#endif // NCNN_GNU_INLINE_ASM } for (; i < size; i++) { - __fp16 xi = x[i]; + __fp16 xi = *x++; - float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); - float32x4_t _weight_xc_R = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); - float32x4_t _weight_xc_U = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); - _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); - _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); + float16x8_t _xi = vdupq_n_f16(xi); + float16x8_t _weight_xc_RU = vld1q_f16(weight_xc_RUN); + _RU = vfmaq_f16(_RU, _weight_xc_RU, _xi); weight_xc_RUN += 8; } + const float* hidden_ptr = hidden_state; + i = 0; for (; i + 3 < num_output; i += 4) { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - float32x4_t _weight_hc_R = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); - float32x4_t _weight_hc_U = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); - float32x4_t _weight_hc_R_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 8)); - float32x4_t _weight_hc_U_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 12)); - float32x4_t _weight_hc_R_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 16)); - float32x4_t _weight_hc_U_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 20)); - float32x4_t _weight_hc_R_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 24)); - float32x4_t _weight_hc_U_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 28)); - _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); - _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); - _sum4 = vfmaq_laneq_f32(_sum4, _weight_hc_U_2, _h_cont, 2); - _sum5 = vfmaq_laneq_f32(_sum5, _weight_hc_R_3, _h_cont, 3); - _sum6 = vfmaq_laneq_f32(_sum6, _weight_hc_U_3, _h_cont, 3); +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v4.4s}, [%0], #16 \n" + "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" + "fcvtn v4.4h, v4.4s \n" + "fmla %2.8h, v0.8h, v4.h[0] \n" + "fmla %3.8h, v1.8h, v4.h[1] \n" + "fmla %4.8h, v2.8h, v4.h[2] \n" + "fmla %5.8h, v3.8h, v4.h[3] \n" + : "=r"(hidden_ptr), + "=r"(weight_hc_RUN), + "=w"(_RU), + "=w"(_sum1), + "=w"(_sum2), + "=w"(_sum3) + : "0"(hidden_ptr), + "1"(weight_hc_RUN), + "2"(_RU), + "3"(_sum1), + "4"(_sum2), + "5"(_sum3) + : "memory", "v0", "v1", "v2", "v3", "v4"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); + float16x8_t _w0 = 
vld1q_f16(weight_hc_RUN); + float16x8_t _w1 = vld1q_f16(weight_hc_RUN + 8); + float16x8_t _w2 = vld1q_f16(weight_hc_RUN + 16); + float16x8_t _w3 = vld1q_f16(weight_hc_RUN + 24); + _RU = vfmaq_lane_f16(_RU, _w0, _h_cont, 0); + _sum1 = vfmaq_lane_f16(_sum1, _w1, _h_cont, 1); + _sum2 = vfmaq_lane_f16(_sum2, _w2, _h_cont, 2); + _sum3 = vfmaq_lane_f16(_sum3, _w3, _h_cont, 3); + hidden_ptr += 4; weight_hc_RUN += 32; +#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = hidden_state[i]; + float h_cont = *hidden_ptr++; - float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_R = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); - float32x4_t _weight_hc_U = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); - _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); - _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); + float16x8_t _h_cont = vdupq_n_f16((__fp16)h_cont); + float16x8_t _weight_hc_RU = vld1q_f16(weight_hc_RUN); + _RU = vfmaq_f16(_RU, _weight_hc_RU, _h_cont); weight_hc_RUN += 8; } - _gru_R = vaddq_f32(_gru_R, _sum1); - _gru_U = vaddq_f32(_gru_U, _sum2); - _sum3 = vaddq_f32(_sum3, _sum5); - _sum4 = vaddq_f32(_sum4, _sum6); - _gru_R = vaddq_f32(_gru_R, _sum3); - _gru_U = vaddq_f32(_gru_U, _sum4); + _RU = vaddq_f16(_RU, _sum1); + _sum2 = vaddq_f16(_sum2, _sum3); + _RU = vaddq_f16(_RU, _sum2); // sigmoid(R) // sigmoid(U) - _gru_R = sigmoid_ps(_gru_R); - _gru_U = sigmoid_ps(_gru_U); + float32x4_t _R32 = sigmoid_ps(vcvt_f32_f16(vget_low_f16(_RU))); + float32x4_t _U32 = sigmoid_ps(vcvt_f32_f16(vget_high_f16(_RU))); + + x -= size; + hidden_ptr = hidden_state; // gate new - float32x4_t _gru_N = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 8)); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); + float16x4_t _gru_N = vld1_f16(bias_c_RUBNWN + 8); + float16x4_t _sum4 = vdup_n_f16((__fp16)0.f); + float16x4_t _sum5 = vdup_n_f16((__fp16)0.f); + float16x4_t _sum6 = vdup_n_f16((__fp16)0.f); i = 0; for (; i + 3 < num_output; i += 4) { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - float32x4_t _weight_hc_N = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); - float32x4_t _weight_hc_N_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); - float32x4_t _weight_hc_N_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 8)); - float32x4_t _weight_hc_N_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 12)); - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_N_2, _h_cont, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v4.4s}, [%0], #16 \n" + "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n" + "fcvtn v4.4h, v4.4s \n" + "fmla %2.4h, v0.4h, v4.h[0] \n" + "fmla %3.4h, v1.4h, v4.h[1] \n" + "fmla %4.4h, v2.4h, v4.h[2] \n" + "fmla %5.4h, v3.4h, v4.h[3] \n" + : "=r"(hidden_ptr), + "=r"(weight_hc_RUN), + "=w"(_gru_N), + "=w"(_sum4), + "=w"(_sum5), + "=w"(_sum6) + : "0"(hidden_ptr), + "1"(weight_hc_RUN), + "2"(_gru_N), + "3"(_sum4), + "4"(_sum5), + "5"(_sum6) + : "memory", "v0", "v1", "v2", "v3", "v4"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); + float16x4_t _w0 = vld1_f16(weight_hc_RUN); + float16x4_t _w1 = vld1_f16(weight_hc_RUN + 4); + float16x4_t _w2 = vld1_f16(weight_hc_RUN + 8); + float16x4_t _w3 = vld1_f16(weight_hc_RUN + 12); + _gru_N = vfma_lane_f16(_gru_N, _w0, _h_cont, 0); + _sum4 = vfma_lane_f16(_sum4, _w1, _h_cont, 1); + _sum5 = vfma_lane_f16(_sum5, _w2, 
_h_cont, 2); + _sum6 = vfma_lane_f16(_sum6, _w3, _h_cont, 3); + hidden_ptr += 4; weight_hc_RUN += 16; +#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = hidden_state[i]; + float h_cont = *hidden_ptr++; - float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_N = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); - _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); + float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont); + float16x4_t _weight_hc_N = vld1_f16(weight_hc_RUN); + _gru_N = vfma_f16(_gru_N, _weight_hc_N, _h_cont); weight_hc_RUN += 4; } - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); + _gru_N = vadd_f16(_gru_N, _sum4); + _sum5 = vadd_f16(_sum5, _sum6); + _gru_N = vadd_f16(_gru_N, _sum5); - _gru_N = vmlaq_f32(vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); - _sum1 = vdupq_n_f32(0.f); - _sum2 = vdupq_n_f32(0.f); - _sum3 = vdupq_n_f32(0.f); + _gru_N = vfma_f16(vld1_f16(bias_c_RUBNWN + 12), vcvt_f16_f32(_R32), _gru_N); + _sum4 = vdup_n_f16((__fp16)0.f); + _sum5 = vdup_n_f16((__fp16)0.f); + _sum6 = vdup_n_f16((__fp16)0.f); i = 0; for (; i + 3 < size; i += 4) { - float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); - float32x4_t _weight_xc_N = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); - float32x4_t _weight_xc_N_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); - float32x4_t _weight_xc_N_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 8)); - float32x4_t _weight_xc_N_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 12)); - _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); +#if NCNN_GNU_INLINE_ASM + asm volatile( + "ld1 {v4.4h}, [%0], #8 \n" + "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n" + "fmla %2.4h, v0.4h, v4.h[0] \n" + "fmla %3.4h, v1.4h, v4.h[1] \n" + "fmla %4.4h, v2.4h, v4.h[2] \n" + "fmla %5.4h, v3.4h, v4.h[3] \n" + : "=r"(x), + "=r"(weight_xc_RUN), + "=w"(_gru_N), + "=w"(_sum4), + "=w"(_sum5), + "=w"(_sum6) + : "0"(x), + "1"(weight_xc_RUN), + "2"(_gru_N), + "3"(_sum4), + "4"(_sum5), + "5"(_sum6) + : "memory", "v0", "v1", "v2", "v3", "v4"); +#else // NCNN_GNU_INLINE_ASM + float16x4_t _x = vld1_f16(x); + float16x4_t _w0 = vld1_f16(weight_xc_RUN); + float16x4_t _w1 = vld1_f16(weight_xc_RUN + 4); + float16x4_t _w2 = vld1_f16(weight_xc_RUN + 8); + float16x4_t _w3 = vld1_f16(weight_xc_RUN + 12); + _gru_N = vfma_lane_f16(_gru_N, _w0, _x, 0); + _sum4 = vfma_lane_f16(_sum4, _w1, _x, 1); + _sum5 = vfma_lane_f16(_sum5, _w2, _x, 2); + _sum6 = vfma_lane_f16(_sum6, _w3, _x, 3); + x += 4; weight_xc_RUN += 16; +#endif // NCNN_GNU_INLINE_ASM } for (; i < size; i++) { - __fp16 xi = x[i]; + __fp16 xi = *x++; - float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); - float32x4_t _weight_xc_N = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); - _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); + float16x4_t _xi = vdup_n_f16(xi); + float16x4_t _weight_xc_N = vld1_f16(weight_xc_RUN); + _gru_N = vfma_f16(_gru_N, _weight_xc_N, _xi); weight_xc_RUN += 4; } - _gru_N = vaddq_f32(_gru_N, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _gru_N = vaddq_f32(_gru_N, _sum2); + _gru_N = vadd_f16(_gru_N, _sum4); + _sum5 = vadd_f16(_sum5, _sum6); + _gru_N = vadd_f16(_gru_N, _sum5); // tanh(N) - _gru_N = tanh_ps(_gru_N); + float32x4_t _N32 = tanh_ps(vcvt_f32_f16(_gru_N)); float* gates_data = gates.row(q / 4); - vst1q_f32(gates_data, _gru_U); - vst1q_f32(gates_data + 4, _gru_N); + 
vst1q_f32(gates_data, _U32); + vst1q_f32(gates_data + 4, _N32); } #pragma omp parallel for num_threads(opt.num_threads) for (int q = remain_num_output_start; q < num_output; q++) @@ -238,64 +314,64 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M const __fp16* weight_xc_RUN = weight_xc.row(q / 4 + q % 4); const __fp16* weight_hc_RUN = weight_hc.row(q / 4 + q % 4); - float R = (float)bias_c_RUBNWN[0]; - float U = (float)bias_c_RUBNWN[1]; + __fp16 R = bias_c_RUBNWN[0]; + __fp16 U = bias_c_RUBNWN[1]; for (int i = 0; i < size; i++) { - float xi = (float)x[i]; + __fp16 xi = x[i]; - R += (float)weight_xc_RUN[0] * xi; - U += (float)weight_xc_RUN[1] * xi; + R += weight_xc_RUN[0] * xi; + U += weight_xc_RUN[1] * xi; weight_xc_RUN += 2; } for (int i = 0; i < num_output; i++) { - float h_cont = hidden_state[i]; + __fp16 h_cont = (__fp16)hidden_state[i]; - R += (float)weight_hc_RUN[0] * h_cont; - U += (float)weight_hc_RUN[1] * h_cont; + R += weight_hc_RUN[0] * h_cont; + U += weight_hc_RUN[1] * h_cont; weight_hc_RUN += 2; } // sigmoid(R) // sigmoid(U) - R = 1.f / (1.f + expf(-R)); - U = 1.f / (1.f + expf(-U)); + float R32 = 1.f / (1.f + expf((float)-R)); + float U32 = 1.f / (1.f + expf((float)-U)); // gate new - float N = (float)bias_c_RUBNWN[2]; + __fp16 N = bias_c_RUBNWN[2]; for (int i = 0; i < num_output; i++) { - float h_cont = hidden_state[i]; + __fp16 h_cont = (__fp16)hidden_state[i]; - N += (float)weight_hc_RUN[0] * h_cont; + N += weight_hc_RUN[0] * h_cont; weight_hc_RUN += 1; } - N = (float)bias_c_RUBNWN[3] + R * N; + N = bias_c_RUBNWN[3] + (__fp16)R32 * N; for (int i = 0; i < size; i++) { - float xi = (float)x[i]; + __fp16 xi = x[i]; - N += (float)weight_xc_RUN[0] * xi; + N += weight_xc_RUN[0] * xi; weight_xc_RUN += 1; } // tanh(N) - N = tanhf(N); + float N32 = tanhf((float)N); float* gates_data = gates.row(q / 4 + q % 4); - gates_data[0] = U; - gates_data[1] = N; + gates_data[0] = U32; + gates_data[1] = N32; } // h_t := (1 - update) .* new + update .* h_{t-1} @@ -338,8 +414,11 @@ static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const M return 0; } -static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) +static int gru_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) { + if (opt.use_fp16_arithmetic) + return gru_fp16sa(bottom_blob, top_blob, reverse, weight_xc, bias_c, weight_hc, hidden_state, opt); + int size = bottom_blob.w; int T = bottom_blob.h; @@ -370,253 +449,177 @@ static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const const __fp16* weight_xc_RUN = weight_xc.row(q / 4); const __fp16* weight_hc_RUN = weight_hc.row(q / 4); - float16x8_t _RU = vld1q_f16(bias_c_RUBNWN); - float16x8_t _sum1 = vdupq_n_f16((__fp16)0.f); - float16x8_t _sum2 = vdupq_n_f16((__fp16)0.f); - float16x8_t _sum3 = vdupq_n_f16((__fp16)0.f); + float32x4_t _gru_R = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN)); + float32x4_t _gru_U = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 4)); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + float32x4_t _sum4 = vdupq_n_f32(0.f); + float32x4_t _sum5 = vdupq_n_f32(0.f); + float32x4_t _sum6 = vdupq_n_f32(0.f); int i = 0; for (; i + 3 < size; i += 4) { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v4.4h}, [%0], #8 \n" 
- "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" - "fmla %2.8h, v0.8h, v4.h[0] \n" - "fmla %3.8h, v1.8h, v4.h[1] \n" - "fmla %4.8h, v2.8h, v4.h[2] \n" - "fmla %5.8h, v3.8h, v4.h[3] \n" - : "=r"(x), - "=r"(weight_xc_RUN), - "=w"(_RU), - "=w"(_sum1), - "=w"(_sum2), - "=w"(_sum3) - : "0"(x), - "1"(weight_xc_RUN), - "2"(_RU), - "3"(_sum1), - "4"(_sum2), - "5"(_sum3) - : "memory", "v0", "v1", "v2", "v3", "v4"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _x = vld1_f16(x); - float16x8_t _w0 = vld1q_f16(weight_xc_RUN); - float16x8_t _w1 = vld1q_f16(weight_xc_RUN + 8); - float16x8_t _w2 = vld1q_f16(weight_xc_RUN + 16); - float16x8_t _w3 = vld1q_f16(weight_xc_RUN + 24); - _RU = vfmaq_lane_f16(_RU, _w0, _x, 0); - _sum1 = vfmaq_lane_f16(_sum1, _w1, _x, 1); - _sum2 = vfmaq_lane_f16(_sum2, _w2, _x, 2); - _sum3 = vfmaq_lane_f16(_sum3, _w3, _x, 3); + float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); + float32x4_t _weight_xc_R = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); + float32x4_t _weight_xc_U = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); + float32x4_t _weight_xc_R_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 8)); + float32x4_t _weight_xc_U_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 12)); + float32x4_t _weight_xc_R_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 16)); + float32x4_t _weight_xc_U_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 20)); + float32x4_t _weight_xc_R_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 24)); + float32x4_t _weight_xc_U_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 28)); + _gru_R = vfmaq_laneq_f32(_gru_R, _weight_xc_R, _xi, 0); + _gru_U = vfmaq_laneq_f32(_gru_U, _weight_xc_U, _xi, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_R_1, _xi, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_U_1, _xi, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_R_2, _xi, 2); + _sum4 = vfmaq_laneq_f32(_sum4, _weight_xc_U_2, _xi, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _weight_xc_R_3, _xi, 3); + _sum6 = vfmaq_laneq_f32(_sum6, _weight_xc_U_3, _xi, 3); - x += 4; weight_xc_RUN += 32; -#endif // NCNN_GNU_INLINE_ASM } for (; i < size; i++) { - __fp16 xi = *x++; + __fp16 xi = x[i]; - float16x8_t _xi = vdupq_n_f16(xi); - float16x8_t _weight_xc_RU = vld1q_f16(weight_xc_RUN); - _RU = vfmaq_f16(_RU, _weight_xc_RU, _xi); + float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); + float32x4_t _weight_xc_R = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); + float32x4_t _weight_xc_U = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); + _gru_R = vmlaq_f32(_gru_R, _weight_xc_R, _xi); + _gru_U = vmlaq_f32(_gru_U, _weight_xc_U, _xi); weight_xc_RUN += 8; } - const float* hidden_ptr = hidden_state; - i = 0; for (; i + 3 < num_output; i += 4) { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v4.4s}, [%0], #16 \n" - "ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [%1], #64 \n" - "fcvtn v4.4h, v4.4s \n" - "fmla %2.8h, v0.8h, v4.h[0] \n" - "fmla %3.8h, v1.8h, v4.h[1] \n" - "fmla %4.8h, v2.8h, v4.h[2] \n" - "fmla %5.8h, v3.8h, v4.h[3] \n" - : "=r"(hidden_ptr), - "=r"(weight_hc_RUN), - "=w"(_RU), - "=w"(_sum1), - "=w"(_sum2), - "=w"(_sum3) - : "0"(hidden_ptr), - "1"(weight_hc_RUN), - "2"(_RU), - "3"(_sum1), - "4"(_sum2), - "5"(_sum3) - : "memory", "v0", "v1", "v2", "v3", "v4"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); - float16x8_t _w0 = vld1q_f16(weight_hc_RUN); - float16x8_t _w1 = vld1q_f16(weight_hc_RUN + 8); - float16x8_t _w2 = vld1q_f16(weight_hc_RUN + 16); - float16x8_t _w3 = vld1q_f16(weight_hc_RUN + 24); - _RU = vfmaq_lane_f16(_RU, _w0, _h_cont, 0); - _sum1 = vfmaq_lane_f16(_sum1, _w1, _h_cont, 1); - _sum2 = 
vfmaq_lane_f16(_sum2, _w2, _h_cont, 2); - _sum3 = vfmaq_lane_f16(_sum3, _w3, _h_cont, 3); + float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); + float32x4_t _weight_hc_R = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); + float32x4_t _weight_hc_U = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); + float32x4_t _weight_hc_R_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 8)); + float32x4_t _weight_hc_U_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 12)); + float32x4_t _weight_hc_R_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 16)); + float32x4_t _weight_hc_U_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 20)); + float32x4_t _weight_hc_R_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 24)); + float32x4_t _weight_hc_U_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 28)); + _gru_R = vfmaq_laneq_f32(_gru_R, _weight_hc_R, _h_cont, 0); + _gru_U = vfmaq_laneq_f32(_gru_U, _weight_hc_U, _h_cont, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_R_1, _h_cont, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_U_1, _h_cont, 1); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_R_2, _h_cont, 2); + _sum4 = vfmaq_laneq_f32(_sum4, _weight_hc_U_2, _h_cont, 2); + _sum5 = vfmaq_laneq_f32(_sum5, _weight_hc_R_3, _h_cont, 3); + _sum6 = vfmaq_laneq_f32(_sum6, _weight_hc_U_3, _h_cont, 3); - hidden_ptr += 4; weight_hc_RUN += 32; -#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = *hidden_ptr++; + float h_cont = hidden_state[i]; - float16x8_t _h_cont = vdupq_n_f16((__fp16)h_cont); - float16x8_t _weight_hc_RU = vld1q_f16(weight_hc_RUN); - _RU = vfmaq_f16(_RU, _weight_hc_RU, _h_cont); + float32x4_t _h_cont = vdupq_n_f32(h_cont); + float32x4_t _weight_hc_R = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); + float32x4_t _weight_hc_U = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); + _gru_R = vmlaq_f32(_gru_R, _weight_hc_R, _h_cont); + _gru_U = vmlaq_f32(_gru_U, _weight_hc_U, _h_cont); weight_hc_RUN += 8; } - _RU = vaddq_f16(_RU, _sum1); - _sum2 = vaddq_f16(_sum2, _sum3); - _RU = vaddq_f16(_RU, _sum2); + _gru_R = vaddq_f32(_gru_R, _sum1); + _gru_U = vaddq_f32(_gru_U, _sum2); + _sum3 = vaddq_f32(_sum3, _sum5); + _sum4 = vaddq_f32(_sum4, _sum6); + _gru_R = vaddq_f32(_gru_R, _sum3); + _gru_U = vaddq_f32(_gru_U, _sum4); // sigmoid(R) // sigmoid(U) - float32x4_t _R32 = sigmoid_ps(vcvt_f32_f16(vget_low_f16(_RU))); - float32x4_t _U32 = sigmoid_ps(vcvt_f32_f16(vget_high_f16(_RU))); - - x -= size; - hidden_ptr = hidden_state; + _gru_R = sigmoid_ps(_gru_R); + _gru_U = sigmoid_ps(_gru_U); // gate new - float16x4_t _gru_N = vld1_f16(bias_c_RUBNWN + 8); - float16x4_t _sum4 = vdup_n_f16((__fp16)0.f); - float16x4_t _sum5 = vdup_n_f16((__fp16)0.f); - float16x4_t _sum6 = vdup_n_f16((__fp16)0.f); + float32x4_t _gru_N = vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 8)); + _sum1 = vdupq_n_f32(0.f); + _sum2 = vdupq_n_f32(0.f); + _sum3 = vdupq_n_f32(0.f); i = 0; - for (; i + 3 < num_output; i += 4) - { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v4.4s}, [%0], #16 \n" - "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n" - "fcvtn v4.4h, v4.4s \n" - "fmla %2.4h, v0.4h, v4.h[0] \n" - "fmla %3.4h, v1.4h, v4.h[1] \n" - "fmla %4.4h, v2.4h, v4.h[2] \n" - "fmla %5.4h, v3.4h, v4.h[3] \n" - : "=r"(hidden_ptr), - "=r"(weight_hc_RUN), - "=w"(_gru_N), - "=w"(_sum4), - "=w"(_sum5), - "=w"(_sum6) - : "0"(hidden_ptr), - "1"(weight_hc_RUN), - "2"(_gru_N), - "3"(_sum4), - "4"(_sum5), - "5"(_sum6) - : "memory", "v0", "v1", "v2", "v3", "v4"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _h_cont = vcvt_f16_f32(vld1q_f32(hidden_ptr)); - float16x4_t _w0 = vld1_f16(weight_hc_RUN); - float16x4_t _w1 = 
vld1_f16(weight_hc_RUN + 4); - float16x4_t _w2 = vld1_f16(weight_hc_RUN + 8); - float16x4_t _w3 = vld1_f16(weight_hc_RUN + 12); - _gru_N = vfma_lane_f16(_gru_N, _w0, _h_cont, 0); - _sum4 = vfma_lane_f16(_sum4, _w1, _h_cont, 1); - _sum5 = vfma_lane_f16(_sum5, _w2, _h_cont, 2); - _sum6 = vfma_lane_f16(_sum6, _w3, _h_cont, 3); + for (; i + 3 < num_output; i += 4) + { + float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); + float32x4_t _weight_hc_N = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); + float32x4_t _weight_hc_N_1 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 4)); + float32x4_t _weight_hc_N_2 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 8)); + float32x4_t _weight_hc_N_3 = vcvt_f32_f16(vld1_f16(weight_hc_RUN + 12)); + _gru_N = vfmaq_laneq_f32(_gru_N, _weight_hc_N, _h_cont, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_N_1, _h_cont, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_N_2, _h_cont, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_N_3, _h_cont, 3); - hidden_ptr += 4; weight_hc_RUN += 16; -#endif // NCNN_GNU_INLINE_ASM } for (; i < num_output; i++) { - float h_cont = *hidden_ptr++; + float h_cont = hidden_state[i]; - float16x4_t _h_cont = vdup_n_f16((__fp16)h_cont); - float16x4_t _weight_hc_N = vld1_f16(weight_hc_RUN); - _gru_N = vfma_f16(_gru_N, _weight_hc_N, _h_cont); + float32x4_t _h_cont = vdupq_n_f32(h_cont); + float32x4_t _weight_hc_N = vcvt_f32_f16(vld1_f16(weight_hc_RUN)); + _gru_N = vmlaq_f32(_gru_N, _weight_hc_N, _h_cont); weight_hc_RUN += 4; } - _gru_N = vadd_f16(_gru_N, _sum4); - _sum5 = vadd_f16(_sum5, _sum6); - _gru_N = vadd_f16(_gru_N, _sum5); + _gru_N = vaddq_f32(_gru_N, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _gru_N = vaddq_f32(_gru_N, _sum2); - _gru_N = vfma_f16(vld1_f16(bias_c_RUBNWN + 12), vcvt_f16_f32(_R32), _gru_N); - _sum4 = vdup_n_f16((__fp16)0.f); - _sum5 = vdup_n_f16((__fp16)0.f); - _sum6 = vdup_n_f16((__fp16)0.f); + _gru_N = vmlaq_f32(vcvt_f32_f16(vld1_f16(bias_c_RUBNWN + 12)), _gru_R, _gru_N); + _sum1 = vdupq_n_f32(0.f); + _sum2 = vdupq_n_f32(0.f); + _sum3 = vdupq_n_f32(0.f); i = 0; for (; i + 3 < size; i += 4) { -#if NCNN_GNU_INLINE_ASM - asm volatile( - "ld1 {v4.4h}, [%0], #8 \n" - "ld1 {v0.4h, v1.4h, v2.4h, v3.4h}, [%1], #32 \n" - "fmla %2.4h, v0.4h, v4.h[0] \n" - "fmla %3.4h, v1.4h, v4.h[1] \n" - "fmla %4.4h, v2.4h, v4.h[2] \n" - "fmla %5.4h, v3.4h, v4.h[3] \n" - : "=r"(x), - "=r"(weight_xc_RUN), - "=w"(_gru_N), - "=w"(_sum4), - "=w"(_sum5), - "=w"(_sum6) - : "0"(x), - "1"(weight_xc_RUN), - "2"(_gru_N), - "3"(_sum4), - "4"(_sum5), - "5"(_sum6) - : "memory", "v0", "v1", "v2", "v3", "v4"); -#else // NCNN_GNU_INLINE_ASM - float16x4_t _x = vld1_f16(x); - float16x4_t _w0 = vld1_f16(weight_xc_RUN); - float16x4_t _w1 = vld1_f16(weight_xc_RUN + 4); - float16x4_t _w2 = vld1_f16(weight_xc_RUN + 8); - float16x4_t _w3 = vld1_f16(weight_xc_RUN + 12); - _gru_N = vfma_lane_f16(_gru_N, _w0, _x, 0); - _sum4 = vfma_lane_f16(_sum4, _w1, _x, 1); - _sum5 = vfma_lane_f16(_sum5, _w2, _x, 2); - _sum6 = vfma_lane_f16(_sum6, _w3, _x, 3); + float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); + float32x4_t _weight_xc_N = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); + float32x4_t _weight_xc_N_1 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 4)); + float32x4_t _weight_xc_N_2 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 8)); + float32x4_t _weight_xc_N_3 = vcvt_f32_f16(vld1_f16(weight_xc_RUN + 12)); + _gru_N = vfmaq_laneq_f32(_gru_N, _weight_xc_N, _xi, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_N_1, _xi, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_N_2, _xi, 2); + _sum3 = 
vfmaq_laneq_f32(_sum3, _weight_xc_N_3, _xi, 3); - x += 4; weight_xc_RUN += 16; -#endif // NCNN_GNU_INLINE_ASM } for (; i < size; i++) { - __fp16 xi = *x++; + __fp16 xi = x[i]; - float16x4_t _xi = vdup_n_f16(xi); - float16x4_t _weight_xc_N = vld1_f16(weight_xc_RUN); - _gru_N = vfma_f16(_gru_N, _weight_xc_N, _xi); + float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); + float32x4_t _weight_xc_N = vcvt_f32_f16(vld1_f16(weight_xc_RUN)); + _gru_N = vmlaq_f32(_gru_N, _weight_xc_N, _xi); weight_xc_RUN += 4; } - _gru_N = vadd_f16(_gru_N, _sum4); - _sum5 = vadd_f16(_sum5, _sum6); - _gru_N = vadd_f16(_gru_N, _sum5); + _gru_N = vaddq_f32(_gru_N, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _gru_N = vaddq_f32(_gru_N, _sum2); // tanh(N) - float32x4_t _N32 = tanh_ps(vcvt_f32_f16(_gru_N)); + _gru_N = tanh_ps(_gru_N); float* gates_data = gates.row(q / 4); - vst1q_f32(gates_data, _U32); - vst1q_f32(gates_data + 4, _N32); + vst1q_f32(gates_data, _gru_U); + vst1q_f32(gates_data + 4, _gru_N); } #pragma omp parallel for num_threads(opt.num_threads) for (int q = remain_num_output_start; q < num_output; q++) @@ -629,64 +632,64 @@ static int gru_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const const __fp16* weight_xc_RUN = weight_xc.row(q / 4 + q % 4); const __fp16* weight_hc_RUN = weight_hc.row(q / 4 + q % 4); - __fp16 R = bias_c_RUBNWN[0]; - __fp16 U = bias_c_RUBNWN[1]; + float R = (float)bias_c_RUBNWN[0]; + float U = (float)bias_c_RUBNWN[1]; for (int i = 0; i < size; i++) { - __fp16 xi = x[i]; + float xi = (float)x[i]; - R += weight_xc_RUN[0] * xi; - U += weight_xc_RUN[1] * xi; + R += (float)weight_xc_RUN[0] * xi; + U += (float)weight_xc_RUN[1] * xi; weight_xc_RUN += 2; } for (int i = 0; i < num_output; i++) { - __fp16 h_cont = (__fp16)hidden_state[i]; + float h_cont = hidden_state[i]; - R += weight_hc_RUN[0] * h_cont; - U += weight_hc_RUN[1] * h_cont; + R += (float)weight_hc_RUN[0] * h_cont; + U += (float)weight_hc_RUN[1] * h_cont; weight_hc_RUN += 2; } // sigmoid(R) // sigmoid(U) - float R32 = 1.f / (1.f + expf((float)-R)); - float U32 = 1.f / (1.f + expf((float)-U)); + R = 1.f / (1.f + expf(-R)); + U = 1.f / (1.f + expf(-U)); // gate new - __fp16 N = bias_c_RUBNWN[2]; + float N = (float)bias_c_RUBNWN[2]; for (int i = 0; i < num_output; i++) { - __fp16 h_cont = (__fp16)hidden_state[i]; + float h_cont = hidden_state[i]; - N += weight_hc_RUN[0] * h_cont; + N += (float)weight_hc_RUN[0] * h_cont; weight_hc_RUN += 1; } - N = bias_c_RUBNWN[3] + (__fp16)R32 * N; + N = (float)bias_c_RUBNWN[3] + R * N; for (int i = 0; i < size; i++) { - __fp16 xi = x[i]; + float xi = (float)x[i]; - N += weight_xc_RUN[0] * xi; + N += (float)weight_xc_RUN[0] * xi; weight_xc_RUN += 1; } // tanh(N) - float N32 = tanhf((float)N); + N = tanhf(N); float* gates_data = gates.row(q / 4 + q % 4); - gates_data[0] = U32; - gates_data[1] = N32; + gates_data[0] = U; + gates_data[1] = N; } // h_t := (1 - update) .* new + update .* h_{t-1} @@ -958,15 +961,19 @@ int GRU_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = gru_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; + { + int ret = gru_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.f); - int ret1 = gru_fp16s(bottom_blob, 
top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; + { + int ret = gru_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -1029,148 +1036,18 @@ int GRU_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector(i); - const __fp16* pr = top_blob_reverse.row(i); - __fp16* ptr = top_blob.row<__fp16>(i); - - memcpy(ptr, pf, num_output * sizeof(__fp16)); - memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + int ret = gru_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); + if (ret != 0) + return ret; } - } - - if (top_blobs.size() == 2) - { - cast_float32_to_float16(hidden, top_blobs[1], opt); - } - - return 0; -} - -int GRU_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const -{ - int T = bottom_blob.h; - - int num_directions = direction == 2 ? 2 : 1; - - // initial hidden state - Mat hidden(num_output, 4u, opt.workspace_allocator); - if (hidden.empty()) - return -100; - hidden.fill(0.f); - - top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator); - if (top_blob.empty()) - return -100; - - // Uni directional - if (direction == 0 || direction == 1) - { - int ret = gru_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } - - if (direction == 2) - { - Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_forward.empty()) - return -100; - - Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_reverse.empty()) - return -100; - - int ret0 = gru_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; - hidden.fill(0.f); - - int ret1 = gru_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; - - // concat w - for (int i = 0; i < T; i++) + Mat hidden1 = hidden.row_range(1, 1); { - const __fp16* pf = top_blob_forward.row(i); - const __fp16* pr = top_blob_reverse.row(i); - __fp16* ptr = top_blob.row<__fp16>(i); - - memcpy(ptr, pf, num_output * sizeof(__fp16)); - memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + int ret = gru_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); + if (ret != 0) + return ret; } - } - - return 0; -} - -int GRU_arm::forward_fp16sa(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const -{ - const Mat& bottom_blob = bottom_blobs[0]; - int T = bottom_blob.h; - int num_directions = direction == 2 ? 2 : 1; - - Mat hidden; - Allocator* hidden_allocator = top_blobs.size() == 2 ? 
opt.blob_allocator : opt.workspace_allocator; - if (bottom_blobs.size() == 2) - { - Option opt_cast = opt; - opt_cast.blob_allocator = hidden_allocator; - cast_float16_to_float32(bottom_blobs[1], hidden, opt_cast); - } - else - { - hidden.create(num_output, num_directions, 4u, hidden_allocator); - if (hidden.empty()) - return -100; - hidden.fill(0.f); - } - - Mat& top_blob = top_blobs[0]; - top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator); - if (top_blob.empty()) - return -100; - - // Uni directional - if (direction == 0 || direction == 1) - { - int ret = gru_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } - - if (direction == 2) - { - Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_forward.empty()) - return -100; - - Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_reverse.empty()) - return -100; - - Mat hidden0 = hidden.row_range(0, 1); - int ret0 = gru_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); - if (ret0 != 0) - return ret0; - - Mat hidden1 = hidden.row_range(1, 1); - int ret1 = gru_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); - if (ret1 != 0) - return ret1; // concat w for (int i = 0; i < T; i++) @@ -1186,11 +1063,11 @@ int GRU_arm::forward_fp16sa(const std::vector& bottom_blobs, std::vector(num_output * 0 + q); + const signed char* weight_xc_U_0 = weight_xc_dr.row(num_output * 1 + q); + const signed char* weight_xc_N_0 = weight_xc_dr.row(num_output * 2 + q); + + const signed char* weight_xc_R_1 = weight_xc_dr.row(num_output * 0 + q + 1); + const signed char* weight_xc_U_1 = weight_xc_dr.row(num_output * 1 + q + 1); + const signed char* weight_xc_N_1 = weight_xc_dr.row(num_output * 2 + q + 1); + + const signed char* weight_xc_R_2 = weight_xc_dr.row(num_output * 0 + q + 2); + const signed char* weight_xc_U_2 = weight_xc_dr.row(num_output * 1 + q + 2); + const signed char* weight_xc_N_2 = weight_xc_dr.row(num_output * 2 + q + 2); + + const signed char* weight_xc_R_3 = weight_xc_dr.row(num_output * 0 + q + 3); + const signed char* weight_xc_U_3 = weight_xc_dr.row(num_output * 1 + q + 3); + const signed char* weight_xc_N_3 = weight_xc_dr.row(num_output * 2 + q + 3); + + const signed char* weight_hc_R_0 = weight_hc_dr.row(num_output * 0 + q); + const signed char* weight_hc_U_0 = weight_hc_dr.row(num_output * 1 + q); + const signed char* weight_hc_N_0 = weight_hc_dr.row(num_output * 2 + q); + + const signed char* weight_hc_R_1 = weight_hc_dr.row(num_output * 0 + q + 1); + const signed char* weight_hc_U_1 = weight_hc_dr.row(num_output * 1 + q + 1); + const signed char* weight_hc_N_1 = weight_hc_dr.row(num_output * 2 + q + 1); + + const signed char* weight_hc_R_2 = weight_hc_dr.row(num_output * 0 + q + 2); + const signed char* weight_hc_U_2 = weight_hc_dr.row(num_output * 1 + q + 2); + const signed char* weight_hc_N_2 = weight_hc_dr.row(num_output * 2 + q + 2); + + const signed char* weight_hc_R_3 = weight_hc_dr.row(num_output * 0 + q + 3); + const signed char* weight_hc_U_3 = weight_hc_dr.row(num_output * 1 + q + 3); + const signed char* weight_hc_N_3 = weight_hc_dr.row(num_output * 2 + q + 3); + + signed char* kptr = weight_data_tm_dr.row(q / 4); + 
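
The loops that follow interleave the int8 R/U/N weights of four consecutive output channels into one contiguous packed row, so the per-timestep kernel can walk a single `kptr` with wide loads. A minimal scalar sketch of the pair-of-two interleave used on the non-dotprod path (names and signature are illustrative, not the helpers in this patch):

```cpp
// Repack the int8 weights of 4 output channels so the kernel reads them
// sequentially: w[ch0][i], w[ch0][i+1], w[ch1][i], w[ch1][i+1], ...
// 'src' is laid out [channel][input] with the given row stride. Sketch only.
static void pack_int8_4ch_pairs(const signed char* src, int stride, int count, signed char* dst)
{
    int i = 0;
    for (; i + 1 < count; i += 2)
    {
        for (int ch = 0; ch < 4; ch++)
        {
            dst[0] = src[ch * stride + i];
            dst[1] = src[ch * stride + i + 1];
            dst += 2;
        }
    }
    for (; i < count; i++) // odd tail: one value per channel
    {
        for (int ch = 0; ch < 4; ch++)
            *dst++ = src[ch * stride + i];
    }
}
```

The dotprod variant above packs groups of four instead of two so that `vdotq_lane_s32` can consume a full 4-element lane per channel.
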
float* descales_ptr = weight_data_tm_int8_descales_dr.row(q / 4); + + int i = 0; +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_R_0[i]; + kptr[1] = weight_xc_R_0[i + 1]; + kptr[2] = weight_xc_R_0[i + 2]; + kptr[3] = weight_xc_R_0[i + 3]; + kptr[4] = weight_xc_R_1[i]; + kptr[5] = weight_xc_R_1[i + 1]; + kptr[6] = weight_xc_R_1[i + 2]; + kptr[7] = weight_xc_R_1[i + 3]; + kptr[8 + 0] = weight_xc_R_2[i]; + kptr[8 + 1] = weight_xc_R_2[i + 1]; + kptr[8 + 2] = weight_xc_R_2[i + 2]; + kptr[8 + 3] = weight_xc_R_2[i + 3]; + kptr[8 + 4] = weight_xc_R_3[i]; + kptr[8 + 5] = weight_xc_R_3[i + 1]; + kptr[8 + 6] = weight_xc_R_3[i + 2]; + kptr[8 + 7] = weight_xc_R_3[i + 3]; + kptr[16 + 0] = weight_xc_U_0[i]; + kptr[16 + 1] = weight_xc_U_0[i + 1]; + kptr[16 + 2] = weight_xc_U_0[i + 2]; + kptr[16 + 3] = weight_xc_U_0[i + 3]; + kptr[16 + 4] = weight_xc_U_1[i]; + kptr[16 + 5] = weight_xc_U_1[i + 1]; + kptr[16 + 6] = weight_xc_U_1[i + 2]; + kptr[16 + 7] = weight_xc_U_1[i + 3]; + kptr[24 + 0] = weight_xc_U_2[i]; + kptr[24 + 1] = weight_xc_U_2[i + 1]; + kptr[24 + 2] = weight_xc_U_2[i + 2]; + kptr[24 + 3] = weight_xc_U_2[i + 3]; + kptr[24 + 4] = weight_xc_U_3[i]; + kptr[24 + 5] = weight_xc_U_3[i + 1]; + kptr[24 + 6] = weight_xc_U_3[i + 2]; + kptr[24 + 7] = weight_xc_U_3[i + 3]; + + kptr += 32; + } +#else + for (; i + 7 < size; i += 8) + { + int8x8_t _w0 = vld1_s8(weight_xc_R_0 + i); + int8x8_t _w1 = vld1_s8(weight_xc_R_1 + i); + int8x8_t _w2 = vld1_s8(weight_xc_R_2 + i); + int8x8_t _w3 = vld1_s8(weight_xc_R_3 + i); + int8x8_t _w4 = vld1_s8(weight_xc_U_0 + i); + int8x8_t _w5 = vld1_s8(weight_xc_U_1 + i); + int8x8_t _w6 = vld1_s8(weight_xc_U_2 + i); + int8x8_t _w7 = vld1_s8(weight_xc_U_3 + i); + + int32x2x2_t _t0 = vtrn_s32(vreinterpret_s32_s8(_w0), vreinterpret_s32_s8(_w4)); + int32x2x2_t _t1 = vtrn_s32(vreinterpret_s32_s8(_w1), vreinterpret_s32_s8(_w5)); + int32x2x2_t _t2 = vtrn_s32(vreinterpret_s32_s8(_w2), vreinterpret_s32_s8(_w6)); + int32x2x2_t _t3 = vtrn_s32(vreinterpret_s32_s8(_w3), vreinterpret_s32_s8(_w7)); + + int32x4x4_t _w; + _w.val[0] = vcombine_s32(_t0.val[0], _t0.val[1]); + _w.val[1] = vcombine_s32(_t1.val[0], _t1.val[1]); + _w.val[2] = vcombine_s32(_t2.val[0], _t2.val[1]); + _w.val[3] = vcombine_s32(_t3.val[0], _t3.val[1]); + + vst4q_s32((int*)kptr, _w); + + kptr += 64; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < size; i += 2) + { + kptr[0] = weight_xc_R_0[i]; + kptr[1] = weight_xc_R_0[i + 1]; + kptr[2] = weight_xc_R_1[i]; + kptr[3] = weight_xc_R_1[i + 1]; + kptr[4] = weight_xc_R_2[i]; + kptr[5] = weight_xc_R_2[i + 1]; + kptr[6] = weight_xc_R_3[i]; + kptr[7] = weight_xc_R_3[i + 1]; + kptr[8 + 0] = weight_xc_U_0[i]; + kptr[8 + 1] = weight_xc_U_0[i + 1]; + kptr[8 + 2] = weight_xc_U_1[i]; + kptr[8 + 3] = weight_xc_U_1[i + 1]; + kptr[8 + 4] = weight_xc_U_2[i]; + kptr[8 + 5] = weight_xc_U_2[i + 1]; + kptr[8 + 6] = weight_xc_U_3[i]; + kptr[8 + 7] = weight_xc_U_3[i + 1]; + + kptr += 16; + } + for (; i < size; i++) + { + kptr[0] = weight_xc_R_0[i]; + kptr[1] = weight_xc_R_1[i]; + kptr[2] = weight_xc_R_2[i]; + kptr[3] = weight_xc_R_3[i]; + kptr[4] = weight_xc_U_0[i]; + kptr[5] = weight_xc_U_1[i]; + kptr[6] = weight_xc_U_2[i]; + kptr[7] = weight_xc_U_3[i]; + + kptr += 8; + } + + i = 0; +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_R_0[i]; + kptr[1] = weight_hc_R_0[i + 1]; + kptr[2] = weight_hc_R_0[i + 2]; + kptr[3] = weight_hc_R_0[i + 3]; + kptr[4] = weight_hc_R_1[i]; + kptr[5] = weight_hc_R_1[i + 1]; + kptr[6] = 
weight_hc_R_1[i + 2]; + kptr[7] = weight_hc_R_1[i + 3]; + kptr[8 + 0] = weight_hc_R_2[i]; + kptr[8 + 1] = weight_hc_R_2[i + 1]; + kptr[8 + 2] = weight_hc_R_2[i + 2]; + kptr[8 + 3] = weight_hc_R_2[i + 3]; + kptr[8 + 4] = weight_hc_R_3[i]; + kptr[8 + 5] = weight_hc_R_3[i + 1]; + kptr[8 + 6] = weight_hc_R_3[i + 2]; + kptr[8 + 7] = weight_hc_R_3[i + 3]; + kptr[16 + 0] = weight_hc_U_0[i]; + kptr[16 + 1] = weight_hc_U_0[i + 1]; + kptr[16 + 2] = weight_hc_U_0[i + 2]; + kptr[16 + 3] = weight_hc_U_0[i + 3]; + kptr[16 + 4] = weight_hc_U_1[i]; + kptr[16 + 5] = weight_hc_U_1[i + 1]; + kptr[16 + 6] = weight_hc_U_1[i + 2]; + kptr[16 + 7] = weight_hc_U_1[i + 3]; + kptr[24 + 0] = weight_hc_U_2[i]; + kptr[24 + 1] = weight_hc_U_2[i + 1]; + kptr[24 + 2] = weight_hc_U_2[i + 2]; + kptr[24 + 3] = weight_hc_U_2[i + 3]; + kptr[24 + 4] = weight_hc_U_3[i]; + kptr[24 + 5] = weight_hc_U_3[i + 1]; + kptr[24 + 6] = weight_hc_U_3[i + 2]; + kptr[24 + 7] = weight_hc_U_3[i + 3]; + + kptr += 32; + } +#else + for (; i + 7 < num_output; i += 8) + { + int8x8_t _w0 = vld1_s8(weight_hc_R_0 + i); + int8x8_t _w1 = vld1_s8(weight_hc_R_1 + i); + int8x8_t _w2 = vld1_s8(weight_hc_R_2 + i); + int8x8_t _w3 = vld1_s8(weight_hc_R_3 + i); + int8x8_t _w4 = vld1_s8(weight_hc_U_0 + i); + int8x8_t _w5 = vld1_s8(weight_hc_U_1 + i); + int8x8_t _w6 = vld1_s8(weight_hc_U_2 + i); + int8x8_t _w7 = vld1_s8(weight_hc_U_3 + i); + + int32x2x2_t _t0 = vtrn_s32(vreinterpret_s32_s8(_w0), vreinterpret_s32_s8(_w4)); + int32x2x2_t _t1 = vtrn_s32(vreinterpret_s32_s8(_w1), vreinterpret_s32_s8(_w5)); + int32x2x2_t _t2 = vtrn_s32(vreinterpret_s32_s8(_w2), vreinterpret_s32_s8(_w6)); + int32x2x2_t _t3 = vtrn_s32(vreinterpret_s32_s8(_w3), vreinterpret_s32_s8(_w7)); + + int32x4x4_t _w; + _w.val[0] = vcombine_s32(_t0.val[0], _t0.val[1]); + _w.val[1] = vcombine_s32(_t1.val[0], _t1.val[1]); + _w.val[2] = vcombine_s32(_t2.val[0], _t2.val[1]); + _w.val[3] = vcombine_s32(_t3.val[0], _t3.val[1]); + + vst4q_s32((int*)kptr, _w); + + kptr += 64; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < num_output; i += 2) + { + kptr[0] = weight_hc_R_0[i]; + kptr[1] = weight_hc_R_0[i + 1]; + kptr[2] = weight_hc_R_1[i]; + kptr[3] = weight_hc_R_1[i + 1]; + kptr[4] = weight_hc_R_2[i]; + kptr[5] = weight_hc_R_2[i + 1]; + kptr[6] = weight_hc_R_3[i]; + kptr[7] = weight_hc_R_3[i + 1]; + kptr[8 + 0] = weight_hc_U_0[i]; + kptr[8 + 1] = weight_hc_U_0[i + 1]; + kptr[8 + 2] = weight_hc_U_1[i]; + kptr[8 + 3] = weight_hc_U_1[i + 1]; + kptr[8 + 4] = weight_hc_U_2[i]; + kptr[8 + 5] = weight_hc_U_2[i + 1]; + kptr[8 + 6] = weight_hc_U_3[i]; + kptr[8 + 7] = weight_hc_U_3[i + 1]; + + kptr += 16; + } + for (; i < num_output; i++) + { + kptr[0] = weight_hc_R_0[i]; + kptr[1] = weight_hc_R_1[i]; + kptr[2] = weight_hc_R_2[i]; + kptr[3] = weight_hc_R_3[i]; + kptr[4] = weight_hc_U_0[i]; + kptr[5] = weight_hc_U_1[i]; + kptr[6] = weight_hc_U_2[i]; + kptr[7] = weight_hc_U_3[i]; + + kptr += 8; + } + + i = 0; +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_N_0[i]; + kptr[1] = weight_hc_N_0[i + 1]; + kptr[2] = weight_hc_N_0[i + 2]; + kptr[3] = weight_hc_N_0[i + 3]; + kptr[4] = weight_hc_N_1[i]; + kptr[5] = weight_hc_N_1[i + 1]; + kptr[6] = weight_hc_N_1[i + 2]; + kptr[7] = weight_hc_N_1[i + 3]; + kptr[8 + 0] = weight_hc_N_2[i]; + kptr[8 + 1] = weight_hc_N_2[i + 1]; + kptr[8 + 2] = weight_hc_N_2[i + 2]; + kptr[8 + 3] = weight_hc_N_2[i + 3]; + kptr[8 + 4] = weight_hc_N_3[i]; + kptr[8 + 5] = weight_hc_N_3[i + 1]; + kptr[8 + 6] = weight_hc_N_3[i + 2]; + kptr[8 + 7] = 
weight_hc_N_3[i + 3]; + + kptr += 16; + } +#else + for (; i + 7 < num_output; i += 8) + { + vst1_s8(kptr, vld1_s8(weight_hc_N_0 + i)); + vst1_s8(kptr + 8, vld1_s8(weight_hc_N_1 + i)); + vst1_s8(kptr + 16, vld1_s8(weight_hc_N_2 + i)); + vst1_s8(kptr + 24, vld1_s8(weight_hc_N_3 + i)); + kptr += 32; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < num_output; i += 2) + { + kptr[0] = weight_hc_N_0[i]; + kptr[1] = weight_hc_N_0[i + 1]; + kptr[2] = weight_hc_N_1[i]; + kptr[3] = weight_hc_N_1[i + 1]; + kptr[4] = weight_hc_N_2[i]; + kptr[5] = weight_hc_N_2[i + 1]; + kptr[6] = weight_hc_N_3[i]; + kptr[7] = weight_hc_N_3[i + 1]; + + kptr += 8; + } + for (; i < num_output; i++) + { + kptr[0] = weight_hc_N_0[i]; + kptr[1] = weight_hc_N_1[i]; + kptr[2] = weight_hc_N_2[i]; + kptr[3] = weight_hc_N_3[i]; + + kptr += 4; + } + + i = 0; +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_N_0[i]; + kptr[1] = weight_xc_N_0[i + 1]; + kptr[2] = weight_xc_N_0[i + 2]; + kptr[3] = weight_xc_N_0[i + 3]; + kptr[4] = weight_xc_N_1[i]; + kptr[5] = weight_xc_N_1[i + 1]; + kptr[6] = weight_xc_N_1[i + 2]; + kptr[7] = weight_xc_N_1[i + 3]; + kptr[8 + 0] = weight_xc_N_2[i]; + kptr[8 + 1] = weight_xc_N_2[i + 1]; + kptr[8 + 2] = weight_xc_N_2[i + 2]; + kptr[8 + 3] = weight_xc_N_2[i + 3]; + kptr[8 + 4] = weight_xc_N_3[i]; + kptr[8 + 5] = weight_xc_N_3[i + 1]; + kptr[8 + 6] = weight_xc_N_3[i + 2]; + kptr[8 + 7] = weight_xc_N_3[i + 3]; + + kptr += 16; + } +#else + for (; i + 7 < size; i += 8) + { + vst1_s8(kptr, vld1_s8(weight_xc_N_0 + i)); + vst1_s8(kptr + 8, vld1_s8(weight_xc_N_1 + i)); + vst1_s8(kptr + 16, vld1_s8(weight_xc_N_2 + i)); + vst1_s8(kptr + 24, vld1_s8(weight_xc_N_3 + i)); + kptr += 32; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < size; i += 2) + { + kptr[0] = weight_xc_N_0[i]; + kptr[1] = weight_xc_N_0[i + 1]; + kptr[2] = weight_xc_N_1[i]; + kptr[3] = weight_xc_N_1[i + 1]; + kptr[4] = weight_xc_N_2[i]; + kptr[5] = weight_xc_N_2[i + 1]; + kptr[6] = weight_xc_N_3[i]; + kptr[7] = weight_xc_N_3[i + 1]; + + kptr += 8; + } + for (; i < size; i++) + { + kptr[0] = weight_xc_N_0[i]; + kptr[1] = weight_xc_N_1[i]; + kptr[2] = weight_xc_N_2[i]; + kptr[3] = weight_xc_N_3[i]; + + kptr += 4; + } + + float32x4_t _xc_R0 = vld1q_f32(weight_xc_int8_scales_ptr + q); + float32x4_t _xc_U0 = vld1q_f32(weight_xc_int8_scales_ptr + num_output + q); + float32x4_t _xc_N0 = vld1q_f32(weight_xc_int8_scales_ptr + num_output * 2 + q); + float32x4_t _hc_R0 = vld1q_f32(weight_hc_int8_scales_ptr + q); + float32x4_t _hc_U0 = vld1q_f32(weight_hc_int8_scales_ptr + num_output + q); + float32x4_t _hc_N0 = vld1q_f32(weight_hc_int8_scales_ptr + num_output * 2 + q); + +#if __aarch64__ + float32x4_t _one = vdupq_n_f32(1.f); + float32x4_t _reciprocal_xc_R0 = vdivq_f32(_one, _xc_R0); + float32x4_t _reciprocal_xc_U0 = vdivq_f32(_one, _xc_U0); + float32x4_t _reciprocal_xc_N0 = vdivq_f32(_one, _xc_N0); + float32x4_t _reciprocal_hc_R0 = vdivq_f32(_one, _hc_R0); + float32x4_t _reciprocal_hc_U0 = vdivq_f32(_one, _hc_U0); + float32x4_t _reciprocal_hc_N0 = vdivq_f32(_one, _hc_N0); +#else + float32x4_t _reciprocal_xc_R0 = vrecpeq_f32(_xc_R0); + float32x4_t _reciprocal_xc_U0 = vrecpeq_f32(_xc_U0); + float32x4_t _reciprocal_xc_N0 = vrecpeq_f32(_xc_N0); + _reciprocal_xc_R0 = vmulq_f32(vrecpsq_f32(_xc_R0, _reciprocal_xc_R0), _reciprocal_xc_R0); + _reciprocal_xc_U0 = vmulq_f32(vrecpsq_f32(_xc_U0, _reciprocal_xc_U0), _reciprocal_xc_U0); + _reciprocal_xc_N0 = vmulq_f32(vrecpsq_f32(_xc_N0, _reciprocal_xc_N0), _reciprocal_xc_N0); + 
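
The 32-bit ARM path around this point cannot use `vdivq_f32`, so each per-gate weight scale is inverted with a `vrecpeq_f32` estimate that is refined twice with `vrecpsq_f32`. A scalar sketch of what that intrinsic pair computes, assuming the usual Newton-Raphson reciprocal iteration:

```cpp
// vrecpsq_f32(x, r) returns (2 - x * r); multiplying by r gives one
// Newton-Raphson step r' = r * (2 - x * r). Two steps take the coarse
// vrecpe estimate close to full single-precision accuracy.
static inline float reciprocal_newton(float x, float r0)
{
    float r = r0;
    r = r * (2.f - x * r); // first refinement
    r = r * (2.f - x * r); // second refinement
    return r;              // ~1.0f / x
}
```

The inverted scales are then stored as "descales", so the hot loop only multiplies and never divides.
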
_reciprocal_xc_R0 = vmulq_f32(vrecpsq_f32(_xc_R0, _reciprocal_xc_R0), _reciprocal_xc_R0); + _reciprocal_xc_U0 = vmulq_f32(vrecpsq_f32(_xc_U0, _reciprocal_xc_U0), _reciprocal_xc_U0); + _reciprocal_xc_N0 = vmulq_f32(vrecpsq_f32(_xc_N0, _reciprocal_xc_N0), _reciprocal_xc_N0); + float32x4_t _reciprocal_hc_R0 = vrecpeq_f32(_hc_R0); + float32x4_t _reciprocal_hc_U0 = vrecpeq_f32(_hc_U0); + float32x4_t _reciprocal_hc_N0 = vrecpeq_f32(_hc_N0); + _reciprocal_hc_R0 = vmulq_f32(vrecpsq_f32(_hc_R0, _reciprocal_hc_R0), _reciprocal_hc_R0); + _reciprocal_hc_U0 = vmulq_f32(vrecpsq_f32(_hc_U0, _reciprocal_hc_U0), _reciprocal_hc_U0); + _reciprocal_hc_N0 = vmulq_f32(vrecpsq_f32(_hc_N0, _reciprocal_hc_N0), _reciprocal_hc_N0); + _reciprocal_hc_R0 = vmulq_f32(vrecpsq_f32(_hc_R0, _reciprocal_hc_R0), _reciprocal_hc_R0); + _reciprocal_hc_U0 = vmulq_f32(vrecpsq_f32(_hc_U0, _reciprocal_hc_U0), _reciprocal_hc_U0); + _reciprocal_hc_N0 = vmulq_f32(vrecpsq_f32(_hc_N0, _reciprocal_hc_N0), _reciprocal_hc_N0); +#endif + + vst1q_f32(descales_ptr, _reciprocal_xc_R0); + vst1q_f32(descales_ptr + 4, _reciprocal_xc_U0); + vst1q_f32(descales_ptr + 8, _reciprocal_hc_R0); + vst1q_f32(descales_ptr + 12, _reciprocal_hc_U0); + vst1q_f32(descales_ptr + 16, _reciprocal_hc_N0); + vst1q_f32(descales_ptr + 20, _reciprocal_xc_N0); + } +#endif // __ARM_NEON + for (; q < num_output; q++) + { + bias_c_RUBNWN[0] = bias_c_R[q]; + bias_c_RUBNWN[1] = bias_c_U[q]; + bias_c_RUBNWN[2] = bias_c_BN[q]; + bias_c_RUBNWN[3] = bias_c_WN[q]; + + bias_c_RUBNWN += 4; + + const signed char* weight_xc_R = weight_xc_dr.row(num_output * 0 + q); + const signed char* weight_xc_U = weight_xc_dr.row(num_output * 1 + q); + const signed char* weight_xc_N = weight_xc_dr.row(num_output * 2 + q); + + const signed char* weight_hc_R = weight_hc_dr.row(num_output * 0 + q); + const signed char* weight_hc_U = weight_hc_dr.row(num_output * 1 + q); + const signed char* weight_hc_N = weight_hc_dr.row(num_output * 2 + q); + +#if __ARM_NEON + signed char* kptr = weight_data_tm_dr.row(q / 4 + q % 4); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q / 4 + q % 4); +#else + signed char* kptr = weight_data_tm_dr.row(q); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q); +#endif // __ARM_NEON + + for (int i = 0; i < size; i++) + { + kptr[0] = weight_xc_R[i]; + kptr[1] = weight_xc_U[i]; + kptr += 2; + } + + for (int i = 0; i < num_output; i++) + { + kptr[0] = weight_hc_R[i]; + kptr[1] = weight_hc_U[i]; + kptr += 2; + } + + for (int i = 0; i < num_output; i++) + { + kptr[0] = weight_hc_N[i]; + kptr += 1; + } + + for (int i = 0; i < size; i++) + { + kptr[0] = weight_xc_N[i]; + kptr += 1; + } + + descales_ptr[0] = 1.f / weight_xc_int8_scales_ptr[num_output * 0 + q]; + descales_ptr[1] = 1.f / weight_xc_int8_scales_ptr[num_output * 1 + q]; + descales_ptr[2] = 1.f / weight_hc_int8_scales_ptr[num_output * 0 + q]; + descales_ptr[3] = 1.f / weight_hc_int8_scales_ptr[num_output * 1 + q]; + descales_ptr[4] = 1.f / weight_hc_int8_scales_ptr[num_output * 2 + q]; + descales_ptr[5] = 1.f / weight_xc_int8_scales_ptr[num_output * 2 + q]; + } + } +} + +static void gru_int8_gate_output(const Mat& gates, Mat& hidden_state, Mat& top_blob, int ti, int elemtype, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2) + if (ncnn::cpu_support_arm_vfpv4()) + { + gru_int8_gate_output_vfpv4(gates, hidden_state, top_blob, ti, elemtype, opt); + return; + } +#endif + + const int num_output = top_blob.w; + + // h_t := (1 - update) .* new + update .* h_{t-1} + 
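
The comment above is the whole of the GRU output stage: with the update gate U and candidate N already computed per element, the new hidden value is a convex blend of the candidate and the previous state. A plain scalar version of the loop that follows (illustrative; the real code vectorizes it and additionally writes the row out as fp32, fp16 or bf16 depending on `elemtype`):

```cpp
// h_t = (1 - u) * n + u * h_{t-1}, applied element-wise over num_output.
static void gru_blend_output(const float* u, const float* n, float* hidden, float* out, int num_output)
{
    for (int q = 0; q < num_output; q++)
    {
        float h = (1.f - u[q]) * n[q] + u[q] * hidden[q];
        hidden[q] = h; // carried into the next timestep
        out[q] = h;    // written to this timestep's output row
    }
}
```
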
float* output_data = top_blob.row(ti); + + float* hidden_ptr = hidden_state; + + int remain_num_output_start = 0; +#if __ARM_NEON + int nn_num_output = num_output >> 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + const float* gates_data = gates.row(q / 4); + + float32x4_t _gru_U0 = vld1q_f32(gates_data); + float32x4_t _gru_N0 = vld1q_f32(gates_data + 4); + + float32x4_t _gru_H0 = vaddq_f32(vmulq_f32(vsubq_f32(vdupq_n_f32(1.f), _gru_U0), _gru_N0), vmulq_f32(_gru_U0, vld1q_f32(hidden_ptr + q))); + + vst1q_f32(hidden_ptr + q, _gru_H0); + + if (elemtype == 1) + { + // fp32 + vst1q_f32(output_data + q, _gru_H0); + } + if (elemtype == 2) + { + // fp16 + unsigned short* outptr = (unsigned short*)output_data + q; +#if (__ARM_FP & 2) +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "fcvtn v0.4h, %2.4s \n" + "st1 {v0.4h}, [%0] \n" + : "=r"(outptr) // %0 + : "0"(outptr), + "w"(_gru_H0) + : "memory", "v0"); +#else // __aarch64__ + asm volatile( + "vcvt.f16.f32 d0, %q2 \n" + "vst1.u16 {d0}, [%0] \n" + : "=r"(outptr) // %0 + : "0"(outptr), + "w"(_gru_H0) + : "memory", "q0"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + vst1_u16(outptr, (uint16x4_t)vcvt_f16_f32(_gru_H0)); +#endif // NCNN_GNU_INLINE_ASM +#else + outptr[q] = float32_to_float16(hidden_ptr[q]); + outptr[q + 1] = float32_to_float16(hidden_ptr[q + 1]); + outptr[q + 2] = float32_to_float16(hidden_ptr[q + 2]); + outptr[q + 3] = float32_to_float16(hidden_ptr[q + 3]); +#endif // (__ARM_FP & 2) + } + if (elemtype == 4) + { + // bf16 + vst1_u16((unsigned short*)output_data + q, float2bfloat(_gru_H0)); + } + } + remain_num_output_start += nn_num_output << 2; +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { +#if __ARM_NEON + const float* gates_data = gates.row(q / 4 + q % 4); +#else + const float* gates_data = gates.row(q); +#endif + + float U = gates_data[0]; + float N = gates_data[1]; + + float H = (1 - U) * N + U * hidden_ptr[q]; + + hidden_ptr[q] = H; + + if (elemtype == 1) + { + output_data[q] = H; + } + if (elemtype == 2) + { + ((unsigned short*)output_data)[q] = float32_to_float16(H); + } + if (elemtype == 4) + { + ((unsigned short*)output_data)[q] = float32_to_bfloat16(H); + } + } +} + +static void gru_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, Mat& hidden_state, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD + if (ncnn::cpu_support_arm_asimddp()) + { + gru_int8_asimddp(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, hidden_state, opt); + return; + } +#endif + + int size = bottom_blob_int8.w; + int T = bottom_blob_int8.h; + + int num_output = top_blob.w; + + // 2 x num_output +#if __ARM_NEON + Mat gates(4 * 2, num_output / 4 + num_output % 4, 4u, opt.workspace_allocator); +#else + Mat gates(2, num_output, 4u, opt.workspace_allocator); +#endif + + Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator); + float hidden_state_int8_scale = 1.f; + float hidden_state_int8_descale = 1.f; + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? 
T - 1 - t : t; + + // dynamic quantize hidden_state + { + float absmax = 0.f; + for (int i = 0; i < num_output; i++) + { + absmax = std::max(absmax, (float)fabs(hidden_state[i])); + } + + if (absmax == 0.f) + { + hidden_state_int8.fill(0); + } + else + { + hidden_state_int8_scale = 127.f / absmax; + hidden_state_int8_descale = absmax / 127.f; + + signed char* hs = hidden_state_int8; + for (int i = 0; i < num_output; i++) + { + hs[i] = float2int8(hidden_state[i] * hidden_state_int8_scale); + } + } + } + + int remain_num_output_start = 0; +#if __ARM_NEON + int nn_num_output = num_output >> 2; + remain_num_output_start = nn_num_output << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = bottom_blob_int8_descales[ti]; + const float descale_h = hidden_state_int8_descale; + + // gate reset update + const float* bias_c_RUBNWN = (const float*)bias_c + q * 4; + + const signed char* kptr = weight_data_tm.row(q / 4); + + const float* descales_ptr = weight_data_tm_int8_descales.row(q / 4); + + int32x4_t _gru_Rx0 = vdupq_n_s32(0); + int32x4_t _gru_Ux0 = vdupq_n_s32(0); + int i = 0; +#if __ARM_FEATURE_DOTPROD + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + for (; i + 7 < size; i += 8) + { + int8x8_t _xi = vld1_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + _gru_Rx0 = vdotq_lane_s32(_gru_Rx0, _w0, _xi, 0); + _gru_Ux0 = vdotq_lane_s32(_gru_Ux0, _w1, _xi, 0); + _sum1 = vdotq_lane_s32(_sum1, _w2, _xi, 1); + _sum2 = vdotq_lane_s32(_sum2, _w3, _xi, 1); + + kptr += 64; + } + _gru_Rx0 = vaddq_s32(_gru_Rx0, _sum1); + _gru_Ux0 = vaddq_s32(_gru_Ux0, _sum2); +#else + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + for (; i + 7 < size; i += 8) + { +#if NCNN_GNU_INLINE_ASM && !__aarch64__ + const signed char* xptr = x + i; + + asm volatile( + "vldm %1!, {d0-d7} \n" + "vld1.s8 {d16}, [%0] \n" + "vdup.32 d17, d16[0] \n" + "vdup.32 d16, d16[1] \n" + "vmull.s8 q4, d0, d17 \n" + "vmull.s8 q5, d1, d17 \n" + "vmull.s8 q6, d2, d17 \n" + "vmull.s8 q7, d3, d17 \n" + "vmlal.s8 q4, d4, d16 \n" + "vmlal.s8 q5, d5, d16 \n" + "vmlal.s8 q6, d6, d16 \n" + "vmlal.s8 q7, d7, d16 \n" + "vpadal.s16 %q2, q4 \n" + "vpadal.s16 %q3, q5 \n" + "vpadal.s16 %q4, q6 \n" + "vpadal.s16 %q5, q7 \n" + : "=r"(xptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3) + : "0"(xptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"); +#else + int32x2_t _xi01 = vreinterpret_s32_s8(vld1_s8(x + i)); + int8x8_t _xi0 = vreinterpret_s8_s32(vdup_lane_s32(_xi01, 0)); + int8x8_t _xi1 = vreinterpret_s8_s32(vdup_lane_s32(_xi01, 1)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _xi0); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _xi0); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _xi0); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _xi0); + _s0 = vmlal_s8(_s0, vget_low_s8(_w2), _xi1); + _s1 = vmlal_s8(_s1, vget_high_s8(_w2), _xi1); + _s2 = vmlal_s8(_s2, vget_low_s8(_w3), _xi1); + _s3 = 
vmlal_s8(_s3, vget_high_s8(_w3), _xi1); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 64; +#endif + } + { + int32x2_t _s0 = vpadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); + int32x2_t _s1 = vpadd_s32(vget_low_s32(_sum1), vget_high_s32(_sum1)); + int32x2_t _s2 = vpadd_s32(vget_low_s32(_sum2), vget_high_s32(_sum2)); + int32x2_t _s3 = vpadd_s32(vget_low_s32(_sum3), vget_high_s32(_sum3)); + _gru_Rx0 = vaddq_s32(_gru_Rx0, vcombine_s32(_s0, _s1)); + _gru_Ux0 = vaddq_s32(_gru_Ux0, vcombine_s32(_s2, _s3)); + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _xi = vld1_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _gru_Rx0 = vdotq_lane_s32(_gru_Rx0, _w0, _xi, 0); + _gru_Ux0 = vdotq_lane_s32(_gru_Ux0, _w1, _xi, 0); +#else + int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i)); + int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0)); + int8x8_t _xi1 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 1)); + int8x16_t _weight_xc_RU0 = vld1q_s8(kptr); + int8x16_t _weight_xc_RU1 = vld1q_s8(kptr + 16); + + int16x8_t _gru_Rx = vmull_s8(vget_low_s8(_weight_xc_RU0), _xi0); + int16x8_t _gru_Ux = vmull_s8(vget_high_s8(_weight_xc_RU0), _xi0); + _gru_Rx = vmlal_s8(_gru_Rx, vget_low_s8(_weight_xc_RU1), _xi1); + _gru_Ux = vmlal_s8(_gru_Ux, vget_high_s8(_weight_xc_RU1), _xi1); + + _gru_Rx0 = vpadalq_s16(_gru_Rx0, _gru_Rx); + _gru_Ux0 = vpadalq_s16(_gru_Ux0, _gru_Ux); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 32; + } + for (; i + 1 < size; i += 2) + { + int8x8_t _xi = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(x + i)), 0)); + int8x16_t _weight_xc_RU = vld1q_s8(kptr); + + int16x8_t _gru_Rx = vmull_s8(vget_low_s8(_weight_xc_RU), _xi); + int16x8_t _gru_Ux = vmull_s8(vget_high_s8(_weight_xc_RU), _xi); + + _gru_Rx0 = vpadalq_s16(_gru_Rx0, _gru_Rx); + _gru_Ux0 = vpadalq_s16(_gru_Ux0, _gru_Ux); + + kptr += 16; + } + for (; i < size; i++) + { + int8x8_t _xi = vdup_n_s8(x[i]); + int8x8_t _weight_xc_RU = vld1_s8(kptr); + + int16x8_t _gru_RxUx = vmull_s8(_weight_xc_RU, _xi); + _gru_Rx0 = vaddw_s16(_gru_Rx0, vget_low_s16(_gru_RxUx)); + _gru_Ux0 = vaddw_s16(_gru_Ux0, vget_high_s16(_gru_RxUx)); + + kptr += 8; + } + + int32x4_t _gru_Rh0 = vdupq_n_s32(0); + int32x4_t _gru_Uh0 = vdupq_n_s32(0); + i = 0; +#if __ARM_FEATURE_DOTPROD + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + for (; i + 7 < num_output; i += 8) + { + int8x8_t _h_cont = vld1_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + _gru_Rh0 = vdotq_lane_s32(_gru_Rh0, _w0, _h_cont, 0); + _gru_Uh0 = vdotq_lane_s32(_gru_Uh0, _w1, _h_cont, 0); + _sum1 = vdotq_lane_s32(_sum1, _w2, _h_cont, 1); + _sum2 = vdotq_lane_s32(_sum2, _w3, _h_cont, 1); + + kptr += 64; + } + _gru_Rh0 = vaddq_s32(_gru_Rh0, _sum1); + _gru_Uh0 = vaddq_s32(_gru_Uh0, _sum2); +#else + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + for (; i + 7 < num_output; i += 8) + { +#if NCNN_GNU_INLINE_ASM && !__aarch64__ + const signed char* hsptr = hs + i; + + asm volatile( + "vldm %1!, {d0-d7} \n" + "vld1.s8 {d16}, [%0] \n" + "vdup.32 d17, d16[0] \n" + "vdup.32 d16, d16[1] \n" + "vmull.s8 q4, d0, d17 \n" + "vmull.s8 q5, d1, d17 \n" + "vmull.s8 q6, d2, d17 \n" + "vmull.s8 q7, d3, d17 \n" + "vmlal.s8 q4, d4, d16 \n" + 
"vmlal.s8 q5, d5, d16 \n" + "vmlal.s8 q6, d6, d16 \n" + "vmlal.s8 q7, d7, d16 \n" + "vpadal.s16 %q2, q4 \n" + "vpadal.s16 %q3, q5 \n" + "vpadal.s16 %q4, q6 \n" + "vpadal.s16 %q5, q7 \n" + : "=r"(hsptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3) + : "0"(hsptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"); +#else + int32x2_t _h_cont01 = vreinterpret_s32_s8(vld1_s8(hs + i)); + int8x8_t _h_cont0 = vreinterpret_s8_s32(vdup_lane_s32(_h_cont01, 0)); + int8x8_t _h_cont1 = vreinterpret_s8_s32(vdup_lane_s32(_h_cont01, 1)); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _h_cont0); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _h_cont0); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _h_cont0); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _h_cont0); + _s0 = vmlal_s8(_s0, vget_low_s8(_w2), _h_cont1); + _s1 = vmlal_s8(_s1, vget_high_s8(_w2), _h_cont1); + _s2 = vmlal_s8(_s2, vget_low_s8(_w3), _h_cont1); + _s3 = vmlal_s8(_s3, vget_high_s8(_w3), _h_cont1); + + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 64; +#endif + } + { + int32x2_t _s0 = vpadd_s32(vget_low_s32(_sum0), vget_high_s32(_sum0)); + int32x2_t _s1 = vpadd_s32(vget_low_s32(_sum1), vget_high_s32(_sum1)); + int32x2_t _s2 = vpadd_s32(vget_low_s32(_sum2), vget_high_s32(_sum2)); + int32x2_t _s3 = vpadd_s32(vget_low_s32(_sum3), vget_high_s32(_sum3)); + _gru_Rh0 = vaddq_s32(_gru_Rh0, vcombine_s32(_s0, _s1)); + _gru_Uh0 = vaddq_s32(_gru_Uh0, vcombine_s32(_s2, _s3)); + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _h_cont = vld1_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _gru_Rh0 = vdotq_lane_s32(_gru_Rh0, _w0, _h_cont, 0); + _gru_Uh0 = vdotq_lane_s32(_gru_Uh0, _w1, _h_cont, 0); +#else + int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i)); + int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0)); + int8x8_t _h_cont1 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 1)); + int8x16_t _weight_hc_RU0 = vld1q_s8(kptr); + int8x16_t _weight_hc_RU1 = vld1q_s8(kptr + 16); + + int16x8_t _gru_Rh = vmull_s8(vget_low_s8(_weight_hc_RU0), _h_cont0); + int16x8_t _gru_Uh = vmull_s8(vget_high_s8(_weight_hc_RU0), _h_cont0); + _gru_Rh = vmlal_s8(_gru_Rh, vget_low_s8(_weight_hc_RU1), _h_cont1); + _gru_Uh = vmlal_s8(_gru_Uh, vget_high_s8(_weight_hc_RU1), _h_cont1); + + _gru_Rh0 = vpadalq_s16(_gru_Rh0, _gru_Rh); + _gru_Uh0 = vpadalq_s16(_gru_Uh0, _gru_Uh); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 32; + } + for (; i + 1 < num_output; i += 2) + { + int8x8_t _h_cont = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(hs + i)), 0)); + int8x16_t _weight_hc_RU = vld1q_s8(kptr); + + int16x8_t _gru_Rh = vmull_s8(vget_low_s8(_weight_hc_RU), _h_cont); + int16x8_t _gru_Uh = vmull_s8(vget_high_s8(_weight_hc_RU), _h_cont); + + _gru_Rh0 = vpadalq_s16(_gru_Rh0, _gru_Rh); + _gru_Uh0 = vpadalq_s16(_gru_Uh0, _gru_Uh); + + kptr += 16; + } + for (; i < num_output; i++) + { + int8x8_t _h_cont = vdup_n_s8(hs[i]); + int8x8_t _weight_hc_RU = vld1_s8(kptr); + + int16x8_t _gru_RhUh = vmull_s8(_weight_hc_RU, _h_cont); + _gru_Rh0 = vaddw_s16(_gru_Rh0, vget_low_s16(_gru_RhUh)); + _gru_Uh0 = vaddw_s16(_gru_Uh0, 
vget_high_s16(_gru_RhUh)); + + kptr += 8; + } + + float32x4_t _descale_x = vdupq_n_f32(descale_x); + float32x4_t _descale_h = vdupq_n_f32(descale_h); + + float32x4_t _gru_R0 = vld1q_f32(bias_c_RUBNWN); + float32x4_t _gru_U0 = vld1q_f32(bias_c_RUBNWN + 4); + + float32x4_t _descale_xc_R0 = vld1q_f32(descales_ptr); + float32x4_t _descale_xc_U0 = vld1q_f32(descales_ptr + 4); + + _gru_R0 = vmlaq_f32(_gru_R0, vcvtq_f32_s32(_gru_Rx0), vmulq_f32(_descale_x, _descale_xc_R0)); + _gru_U0 = vmlaq_f32(_gru_U0, vcvtq_f32_s32(_gru_Ux0), vmulq_f32(_descale_x, _descale_xc_U0)); + + float32x4_t _descale_hc_R0 = vld1q_f32(descales_ptr + 8); + float32x4_t _descale_hc_U0 = vld1q_f32(descales_ptr + 12); + + _gru_R0 = vmlaq_f32(_gru_R0, vcvtq_f32_s32(_gru_Rh0), vmulq_f32(_descale_h, _descale_hc_R0)); + _gru_U0 = vmlaq_f32(_gru_U0, vcvtq_f32_s32(_gru_Uh0), vmulq_f32(_descale_h, _descale_hc_U0)); + + // sigmoid(R) + // sigmoid(U) + _gru_R0 = sigmoid_ps(_gru_R0); + _gru_U0 = sigmoid_ps(_gru_U0); + + // gate new + + int32x4_t _gru_Nh0 = vdupq_n_s32(0); + i = 0; +#if __ARM_FEATURE_DOTPROD + _sum1 = vdupq_n_s32(0); + for (; i + 7 < num_output; i += 8) + { + int8x8_t _h_cont = vld1_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _gru_Nh0 = vdotq_lane_s32(_gru_Nh0, _w0, _h_cont, 0); + _sum1 = vdotq_lane_s32(_sum1, _w1, _h_cont, 1); + + kptr += 32; + } + _gru_Nh0 = vaddq_s32(_gru_Nh0, _sum1); +#else + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + for (; i + 15 < num_output; i += 16) + { +#if NCNN_GNU_INLINE_ASM && !__aarch64__ + const signed char* hsptr = hs + i; + + asm volatile( + "vldm %1!, {d0-d7} \n" + "vld1.s8 {d16-d17}, [%0] \n" + "vmull.s8 q4, d0, d16 \n" + "vmull.s8 q5, d1, d16 \n" + "vmull.s8 q6, d2, d16 \n" + "vmull.s8 q7, d3, d16 \n" + "vmlal.s8 q4, d4, d17 \n" + "vmlal.s8 q5, d5, d17 \n" + "vmlal.s8 q6, d6, d17 \n" + "vmlal.s8 q7, d7, d17 \n" + "vpadal.s16 %q2, q4 \n" + "vpadal.s16 %q3, q5 \n" + "vpadal.s16 %q4, q6 \n" + "vpadal.s16 %q5, q7 \n" + : "=r"(hsptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3) + : "0"(hsptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"); +#else + int8x16_t _h_cont = vld1q_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_h_cont)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_h_cont)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), vget_low_s8(_h_cont)); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_h_cont)); + _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_h_cont)); + _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_h_cont)); + _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_h_cont)); + _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_h_cont)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 64; +#endif + } + for (; i + 7 < num_output; i += 8) + { + int8x8_t _h_cont = vld1_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _h_cont); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _h_cont); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _h_cont); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), 
_h_cont); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 32; + } + { + int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1])); + } + _gru_Nh0 = vaddq_s32(_gru_Nh0, _sum0); + _gru_Nh0 = vaddq_s32(_gru_Nh0, _sum1); + _gru_Nh0 = vaddq_s32(_gru_Nh0, _sum2); + _gru_Nh0 = vaddq_s32(_gru_Nh0, _sum3); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _h_cont = vld1_s8(hs + i); + int8x16_t _w = vld1q_s8(kptr); + _gru_Nh0 = vdotq_lane_s32(_gru_Nh0, _w, _h_cont, 0); +#else + int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i)); + int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0)); + int8x8_t _h_cont1 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 1)); + int8x16_t _w01 = vld1q_s8(kptr); + + int16x8_t _gru_Nh = vmull_s8(vget_low_s8(_w01), _h_cont0); + _gru_Nh = vmlal_s8(_gru_Nh, vget_high_s8(_w01), _h_cont1); + _gru_Nh0 = vpadalq_s16(_gru_Nh0, _gru_Nh); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 16; + } + for (; i + 1 < num_output; i += 2) + { + int8x8_t _h_cont = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(hs + i)), 0)); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _gru_Nh = vmull_s8(_w, _h_cont); + _gru_Nh0 = vpadalq_s16(_gru_Nh0, _gru_Nh); + + kptr += 8; + } + for (; i < num_output; i++) + { + int8x8_t _h_cont = vdup_n_s8(hs[i]); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _gru_Nh = vmull_s8(_w, _h_cont); + _gru_Nh0 = vaddw_s16(_gru_Nh0, vget_low_s16(_gru_Nh)); + + kptr += 4; + } + + int32x4_t _gru_Nx0 = vdupq_n_s32(0); + i = 0; +#if __ARM_FEATURE_DOTPROD + _sum1 = vdupq_n_s32(0); + for (; i + 7 < size; i += 8) + { + int8x8_t _xi = vld1_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _gru_Nx0 = vdotq_lane_s32(_gru_Nx0, _w0, _xi, 0); + _sum1 = vdotq_lane_s32(_sum1, _w1, _xi, 1); + + kptr += 32; + } + _gru_Nx0 = vaddq_s32(_gru_Nx0, _sum1); +#else + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + for (; i + 15 < size; i += 16) + { +#if NCNN_GNU_INLINE_ASM && !__aarch64__ + const signed char* xptr = x + i; + + asm volatile( + "vldm %1!, {d0-d7} \n" + "vld1.s8 {d16-d17}, [%0] \n" + "vmull.s8 q4, d0, d16 \n" + "vmull.s8 q5, d1, d16 \n" + "vmull.s8 q6, d2, d16 \n" + "vmull.s8 q7, d3, d16 \n" + "vmlal.s8 q4, d4, d17 \n" + "vmlal.s8 q5, d5, d17 \n" + "vmlal.s8 q6, d6, d17 \n" + "vmlal.s8 q7, d7, d17 \n" + "vpadal.s16 %q2, q4 \n" + "vpadal.s16 %q3, q5 \n" + "vpadal.s16 %q4, q6 \n" + "vpadal.s16 %q5, q7 \n" + : "=r"(xptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3) + : "0"(xptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"); +#else + int8x16_t _xi = vld1q_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_xi)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_xi)); + int16x8_t _s2 
= vmull_s8(vget_low_s8(_w1), vget_low_s8(_xi)); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_xi)); + _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_xi)); + _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_xi)); + _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_xi)); + _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_xi)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 64; +#endif + } + for (; i + 7 < size; i += 8) + { + int8x8_t _xi = vld1_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _xi); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _xi); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _xi); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _xi); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 32; + } + { + int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1])); + } + _gru_Nx0 = vaddq_s32(_gru_Nx0, _sum0); + _gru_Nx0 = vaddq_s32(_gru_Nx0, _sum1); + _gru_Nx0 = vaddq_s32(_gru_Nx0, _sum2); + _gru_Nx0 = vaddq_s32(_gru_Nx0, _sum3); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _xi = vld1_s8(x + i); + int8x16_t _w = vld1q_s8(kptr); + _gru_Nx0 = vdotq_lane_s32(_gru_Nx0, _w, _xi, 0); +#else + int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i)); + int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0)); + int8x8_t _xi1 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 1)); + int8x16_t _w01 = vld1q_s8(kptr); + + int16x8_t _gru_Nx = vmull_s8(vget_low_s8(_w01), _xi0); + _gru_Nx = vmlal_s8(_gru_Nx, vget_high_s8(_w01), _xi1); + _gru_Nx0 = vpadalq_s16(_gru_Nx0, _gru_Nx); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 16; + } + for (; i + 1 < size; i += 2) + { + int8x8_t _xi = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(x + i)), 0)); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _gru_Nx = vmull_s8(_w, _xi); + _gru_Nx0 = vpadalq_s16(_gru_Nx0, _gru_Nx); + + kptr += 8; + } + for (; i < size; i++) + { + int8x8_t _xi = vdup_n_s8(x[i]); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _gru_Nx = vmull_s8(_w, _xi); + _gru_Nx0 = vaddw_s16(_gru_Nx0, vget_low_s16(_gru_Nx)); + + kptr += 4; + } + + float32x4_t _gru_N0 = vld1q_f32(bias_c_RUBNWN + 8); + + float32x4_t _descale_hc_N0 = vld1q_f32(descales_ptr + 16); + + _gru_N0 = vmlaq_f32(_gru_N0, vcvtq_f32_s32(_gru_Nh0), vmulq_f32(_descale_h, _descale_hc_N0)); + + _gru_N0 = vmlaq_f32(vld1q_f32(bias_c_RUBNWN + 12), _gru_R0, _gru_N0); + + float32x4_t _descale_xc_N0 = vld1q_f32(descales_ptr + 20); + + _gru_N0 = vmlaq_f32(_gru_N0, vcvtq_f32_s32(_gru_Nx0), vmulq_f32(_descale_x, _descale_xc_N0)); + + // tanh(N) + _gru_N0 = tanh_ps(_gru_N0); + + float* gates_data = gates.row(q / 4); + + vst1q_f32(gates_data, _gru_U0); + vst1q_f32(gates_data + 4, _gru_N0); + } +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const signed char* x = 
bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = bottom_blob_int8_descales[ti]; + const float descale_h = hidden_state_int8_descale; + + // gate reset update + const float* bias_c_RUBNWN = (const float*)bias_c + q * 4; + +#if __ARM_NEON + const signed char* kptr = weight_data_tm.row(q / 4 + q % 4); + const float* descales_ptr = weight_data_tm_int8_descales.row(q / 4 + q % 4); +#else + const signed char* kptr = weight_data_tm.row(q); + const float* descales_ptr = weight_data_tm_int8_descales.row(q); +#endif + + const float descale_xc_R = descales_ptr[0]; + const float descale_xc_U = descales_ptr[1]; + const float descale_hc_R = descales_ptr[2]; + const float descale_hc_U = descales_ptr[3]; + const float descale_hc_N = descales_ptr[4]; + const float descale_xc_N = descales_ptr[5]; + + int Rx = 0; + int Ux = 0; + for (int i = 0; i < size; i++) + { + signed char xi = x[i]; + + Rx += kptr[0] * xi; + Ux += kptr[1] * xi; + + kptr += 2; + } + + int Rh = 0; + int Uh = 0; + for (int i = 0; i < num_output; i++) + { + signed char h_cont = hs[i]; + + Rh += kptr[0] * h_cont; + Uh += kptr[1] * h_cont; + + kptr += 2; + } + + float R = bias_c_RUBNWN[0] + Rx * (descale_x * descale_xc_R) + Rh * (descale_h * descale_hc_R); + float U = bias_c_RUBNWN[1] + Ux * (descale_x * descale_xc_U) + Uh * (descale_h * descale_hc_U); + + // sigmoid(R) + // sigmoid(U) + R = 1.f / (1.f + expf(-R)); + U = 1.f / (1.f + expf(-U)); + + // gate new + + int Nh = 0; + for (int i = 0; i < num_output; i++) + { + Nh += kptr[0] * hs[i]; + kptr += 1; + } + + int Nx = 0; + for (int i = 0; i < size; i++) + { + Nx += kptr[0] * x[i]; + kptr += 1; + } + + float N = bias_c_RUBNWN[2] + Nh * (descale_h * descale_hc_N); + N = bias_c_RUBNWN[3] + R * N + Nx * (descale_x * descale_xc_N); + + // tanh(N) + N = tanhf(N); + +#if __ARM_NEON + float* gates_data = gates.row(q / 4 + q % 4); +#else + float* gates_data = gates.row(q); +#endif + + gates_data[0] = U; + gates_data[1] = N; + } + + gru_int8_gate_output(gates, hidden_state, top_blob, ti, elemtype, opt); + } +} diff --git a/src/layer/arm/lstm_arm.cpp b/src/layer/arm/lstm_arm.cpp index 04d7277547e..a6db35ad23a 100644 --- a/src/layer/arm/lstm_arm.cpp +++ b/src/layer/arm/lstm_arm.cpp @@ -25,6 +25,8 @@ namespace ncnn { +#include "lstm_int8.h" + LSTM_arm::LSTM_arm() { #if __ARM_NEON @@ -40,6 +42,13 @@ LSTM_arm::LSTM_arm() int LSTM_arm::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + return create_pipeline_int8(opt); + } +#endif + #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage) { @@ -406,16 +415,18 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blob, top_blob, opt); + } +#endif + int elembits = bottom_blob.elembits(); #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - { - if (opt.use_fp16_arithmetic) - return forward_fp16sa(bottom_blob, top_blob, opt); - else - return forward_fp16s(bottom_blob, top_blob, opt); - } + return forward_fp16s(bottom_blob, top_blob, opt); #endif #if NCNN_BF16 @@ -460,16 +471,20 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob_reverse.empty()) return -100; - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), 
num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret0 != 0) - return ret0; + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } hidden.fill(0.0f); cell.fill(0.0f); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret1 != 0) - return ret1; + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -488,17 +503,19 @@ int LSTM_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) int LSTM_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blobs, top_blobs, opt); + } +#endif + const Mat& bottom_blob = bottom_blobs[0]; int elembits = bottom_blob.elembits(); #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - { - if (opt.use_fp16_arithmetic) - return forward_fp16sa(bottom_blobs, top_blobs, opt); - else - return forward_fp16s(bottom_blobs, top_blobs, opt); - } + return forward_fp16s(bottom_blobs, top_blobs, opt); #endif #if NCNN_BF16 @@ -555,15 +572,19 @@ int LSTM_arm::forward(const std::vector& bottom_blobs, std::vector& to Mat hidden0 = hidden.row_range(0, 1); Mat cell0 = cell.row_range(0, 1); - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); - if (ret0 != 0) - return ret0; + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } Mat hidden1 = hidden.row_range(1, 1); Mat cell1 = cell.row_range(1, 1); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); - if (ret1 != 0) - return ret1; + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -980,16 +1001,20 @@ int LSTM_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret0 != 0) - return ret0; + { + int ret = lstm_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } hidden.fill(0.f); cell.fill(0.f); - int ret1 = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret1 != 0) - return ret1; + { + int ret = lstm_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -1060,15 +1085,19 @@ int LSTM_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(float16_to_float32(x[i]))); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + bottom_blob_int8_descales[t] = absmax / 127.f; + } + } + if (elemtype == 4) + { + // bf16 + for (int t = 0; t < T; t++) + { + const unsigned short* x = bottom_blob.row(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(bfloat16_to_float32(x[i]))); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + bottom_blob_int8_descales[t] = absmax / 127.f; + } + } + + quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt); +} + +int LSTM_arm::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int elemtype = 1; // fp32 + { + int elembits = bottom_blob.elembits(); + + // clang-format off + // *INDENT-OFF* + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + { + elemtype = 2; // fp16 + } + else +#endif +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + { + elemtype = 4; // bf16 + } + else +#endif + { + // fp32 + } + + // *INDENT-ON* + // clang-format on + } + + int T = bottom_blob.h; + size_t elemsize = bottom_blob.elemsize; + + int num_directions = direction == 2 ? 2 : 1; + + // initial hidden state + Mat hidden(num_output, 4u, opt.workspace_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + + Mat cell(hidden_size, 4u, opt.workspace_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + + top_blob.create(num_output * num_directions, T, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // dynamic quantize bottom_blob + Mat bottom_blob_int8; + Mat bottom_blob_int8_descales; + { + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + dynamic_quantize(bottom_blob, elemtype, bottom_blob_int8, bottom_blob_int8_descales, opt_quant); + } + + // Uni directional + if (direction == 0 || direction == 1) + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, direction, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(0), hidden, cell, opt); + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_forward, elemtype, 0, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + } + + hidden.fill(0.f); + cell.fill(0.0f); + + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_reverse, elemtype, 1, weight_data_tm.channel(1), weight_data_tm_int8_descales.channel(1), bias_c_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + } + + // concat w + for (int i = 0; i < T; i++) + { + const unsigned char* pf = top_blob_forward.row(i); + const unsigned char* pr = top_blob_reverse.row(i); + unsigned char* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * elemsize); + memcpy(ptr + num_output * elemsize, pr, num_output * elemsize); + } + } + + return 0; +} + +int LSTM_arm::forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + + int elemtype = 1; // fp32 + { + int elembits = bottom_blob.elembits(); + + // clang-format off + // *INDENT-OFF* + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + { + elemtype = 2; // fp16 + } + else +#endif +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + { + elemtype = 4; // bf16 + } + else +#endif + { + // fp32 + } + + // *INDENT-ON* + // clang-format on + } + + int T = bottom_blob.h; + size_t elemsize = bottom_blob.elemsize; + int num_directions = direction == 2 ? 2 : 1; + + Mat hidden; + Mat cell; + Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? 
opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 3) + { + if (elemtype == 1) + { + hidden = bottom_blobs[1].clone(hidden_cell_allocator); + cell = bottom_blobs[2].clone(hidden_cell_allocator); + } + if (elemtype == 2) + { + Option opt_cast = opt; + opt_cast.blob_allocator = hidden_cell_allocator; + cast_float16_to_float32(bottom_blobs[1], hidden, opt_cast); + cast_float16_to_float32(bottom_blobs[2], cell, opt_cast); + } + if (elemtype == 4) + { + Option opt_cast = opt; + opt_cast.blob_allocator = hidden_cell_allocator; + cast_bfloat16_to_float32(bottom_blobs[1], hidden, opt_cast); + cast_bfloat16_to_float32(bottom_blobs[2], cell, opt_cast); + } + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_cell_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + + cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // dynamic quantize bottom_blob + Mat bottom_blob_int8; + Mat bottom_blob_int8_descales; + { + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + dynamic_quantize(bottom_blob, elemtype, bottom_blob_int8, bottom_blob_int8_descales, opt_quant); + } + + // Uni directional + if (direction == 0 || direction == 1) + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, direction, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + Mat cell0 = cell.row_range(0, 1); + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_forward, elemtype, 0, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + } + + Mat hidden1 = hidden.row_range(1, 1); + Mat cell1 = cell.row_range(1, 1); + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_reverse, elemtype, 1, weight_data_tm.channel(1), weight_data_tm_int8_descales.channel(1), bias_c_data_packed.channel(1), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + } + + // concat w + for (int i = 0; i < T; i++) + { + const unsigned char* pf = top_blob_forward.row(i); + const unsigned char* pr = top_blob_reverse.row(i); + unsigned char* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * elemsize); + memcpy(ptr + num_output * elemsize, pr, num_output * elemsize); + } + } + + if (top_blobs.size() == 3) + { + if (elemtype == 1) + { + top_blobs[1] = hidden; + top_blobs[2] = cell; + } + if (elemtype == 2) + { + cast_float32_to_float16(hidden, top_blobs[1], opt); + cast_float32_to_float16(cell, top_blobs[2], opt); + } + if (elemtype == 4) + { + cast_float32_to_bfloat16(hidden, top_blobs[1], opt); + cast_float32_to_bfloat16(cell, top_blobs[2], opt); + } + } + + return 0; +} +#endif // NCNN_INT8 + } // namespace ncnn diff --git a/src/layer/arm/lstm_arm.h b/src/layer/arm/lstm_arm.h index b5ee1092a52..cd6546103ca 100644 --- a/src/layer/arm/lstm_arm.h +++ b/src/layer/arm/lstm_arm.h @@ -33,19 +33,29 @@ class LSTM_arm : public LSTM int create_pipeline_fp16s(const Option& opt); int forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; int forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; - int forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; - int forward_fp16sa(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif #if NCNN_BF16 int create_pipeline_bf16s(const Option& opt); int forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; int forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif +#if NCNN_INT8 + int create_pipeline_int8(const Option& opt); + void dynamic_quantize(const Mat& bottom_blob, int elemtype, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const; + int forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif public: Mat weight_xc_data_packed; Mat bias_c_data_packed; Mat weight_hc_data_packed; + + Mat weight_data_tm; + +#if NCNN_INT8 + Mat weight_data_tm_int8_descales; +#endif }; } // namespace ncnn diff --git a/src/layer/arm/lstm_arm_asimddp.cpp b/src/layer/arm/lstm_arm_asimddp.cpp new file mode 100644 index 00000000000..966dafabb4c --- /dev/null +++ b/src/layer/arm/lstm_arm_asimddp.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
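+
+// This translation unit is presumably built with the armv8.2 asimddp (dotprod)
+// compile target, so the shared lstm_int8.h implementation included below
+// specializes its int8 kernels to the vdotq paths; the *_asimddp wrappers here
+// simply forward to it and are selected at runtime via
+// ncnn::cpu_support_arm_asimddp() from the generic lstm_int8() entry point.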
+ +#include "cpu.h" +#include "mat.h" +#include "layer.h" +#include "arm_activation.h" +#include "arm_usability.h" + +namespace ncnn { + +#include "lstm_int8.h" + +void lstm_transform_weight_int8_asimddp(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt) +{ + lstm_transform_weight_int8(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, hidden_size, opt); +} + +void lstm_int8_asimddp(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); +} + +} // namespace ncnn diff --git a/src/layer/arm/lstm_arm_asimdhp.cpp b/src/layer/arm/lstm_arm_asimdhp.cpp index 1d3fc71cdfc..f9d96d21cf8 100644 --- a/src/layer/arm/lstm_arm_asimdhp.cpp +++ b/src/layer/arm/lstm_arm_asimdhp.cpp @@ -23,226 +23,6 @@ namespace ncnn { #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) -{ - int size = bottom_blob.w; - int T = bottom_blob.h; - - int num_output = top_blob.w; - int hidden_size = cell_state.w; - - // 4 x hidden_size - Mat gates(4, hidden_size, 4u, opt.workspace_allocator); - if (gates.empty()) - return -100; - - Mat tmp_hidden_state; - if (num_output != hidden_size) - { - tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); - if (tmp_hidden_state.empty()) - return -100; - } - - // unroll - for (int t = 0; t < T; t++) - { - // clip hidden by continuation indicator - // h_cont_{t-1} = cont_t * h_{t-1} - // h_cont_{t-1} = h_{t-1} if cont_t == 1 - // 0 otherwise - // calculate hidden - // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c - - int ti = reverse ? 
T - 1 - t : t; - - const __fp16* x = bottom_blob.row(ti); - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = 0; q < hidden_size; q++) - { - const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4; - - // gate I F O G - const __fp16* weight_xc_IFOG = weight_xc.row(q); - - const __fp16* weight_hc_IFOG = weight_hc.row(q); - - float32x4_t _IFOG = vcvt_f32_f16(vld1_f16(bias_c_IFOG)); - float32x4_t _sum1 = vdupq_n_f32(0.f); - float32x4_t _sum2 = vdupq_n_f32(0.f); - float32x4_t _sum3 = vdupq_n_f32(0.f); - - int i = 0; - for (; i + 3 < size; i += 4) - { - float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); - - float32x4_t _weight_xc_IFOG_0 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG)); - float32x4_t _weight_xc_IFOG_1 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 4)); - float32x4_t _weight_xc_IFOG_2 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 8)); - float32x4_t _weight_xc_IFOG_3 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 12)); - - _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_IFOG_1, _xi, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_IFOG_2, _xi, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_IFOG_3, _xi, 3); - - weight_xc_IFOG += 16; - } - for (; i < size; i++) - { - __fp16 xi = x[i]; - - float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); - float32x4_t _weight_xc_IFOG = vcvt_f32_f16(vld1_f16(weight_xc_IFOG)); - _IFOG = vfmaq_f32(_IFOG, _weight_xc_IFOG, _xi); - - weight_xc_IFOG += 4; - } - - i = 0; - for (; i + 3 < num_output; i += 4) - { - float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); - - float32x4_t _weight_hc_IFOG_0 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG)); - float32x4_t _weight_hc_IFOG_1 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 4)); - float32x4_t _weight_hc_IFOG_2 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 8)); - float32x4_t _weight_hc_IFOG_3 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 12)); - - _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_IFOG_1, _h_cont, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_IFOG_2, _h_cont, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_IFOG_3, _h_cont, 3); - - weight_hc_IFOG += 16; - } - for (; i < num_output; i++) - { - float h_cont = hidden_state[i]; - - float32x4_t _h_cont = vdupq_n_f32(h_cont); - float32x4_t _weight_hc_IFOG = vcvt_f32_f16(vld1_f16(weight_hc_IFOG)); - _IFOG = vfmaq_f32(_IFOG, _weight_hc_IFOG, _h_cont); - - weight_hc_IFOG += 4; - } - - float* gates_data = gates.row(q); - - _IFOG = vaddq_f32(_IFOG, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _IFOG = vaddq_f32(_IFOG, _sum2); - - vst1q_f32(gates_data, _IFOG); - } - - // lstm unit - // sigmoid(I) - // sigmoid(F) - // sigmoid(O) - // tanh(G) - // c_t := f_t .* c_{t-1} + i_t .* g_t - // h_t := o_t .* tanh[c_t] - __fp16* output_data = top_blob.row<__fp16>(ti); - - float* cell_ptr = cell_state; - float* hidden_ptr = hidden_state; - float* tmp_hidden_ptr = tmp_hidden_state; - - int nn_hidden_size = hidden_size >> 2; - int remain_hidden_size_start = nn_hidden_size << 2; - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_hidden_size; qq++) - { - int q = qq * 4; - - const float* gates_data = gates.row(q); - - float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data); - - float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]); - float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]); - float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]); - float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]); - - float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr 
+ q)), vmulq_f32(_lstm_I, _lstm_G)); - float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2)); - - vst1q_f32(cell_ptr + q, _cell2); - - if (num_output == hidden_size) - { - vst1q_f32(hidden_ptr + q, _lstm_H); - vst1_f16(output_data + q, vcvt_f16_f32(_lstm_H)); - } - else - { - vst1q_f32(tmp_hidden_ptr + q, _lstm_H); - } - } - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_hidden_size_start; q < hidden_size; q++) - { - const float* gates_data = gates.row(q); - - float I = gates_data[0]; - float F = gates_data[1]; - float O = gates_data[2]; - float G = gates_data[3]; - - I = 1.f / (1.f + expf(-I)); - F = 1.f / (1.f + expf(-F)); - O = 1.f / (1.f + expf(-O)); - G = tanhf(G); - - float cell2 = F * cell_ptr[q] + I * G; - float H = O * tanhf(cell2); - - cell_ptr[q] = cell2; - if (num_output == hidden_size) - { - hidden_ptr[q] = H; - output_data[q] = (__fp16)H; - } - else - { - tmp_hidden_ptr[q] = H; - } - } - - if (num_output != hidden_size) - { - // int nn_num_output = num_output >> 2; - // int remain_num_output_start = nn_num_output << 2; - // #pragma omp parallel for num_threads(opt.num_threads) - // for (int qq = 0; qq < nn_num_output; qq++) - // { - // int q = qq * 4; - // - // } - int remain_num_output_start = 0; - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { - const float* hr = weight_hr.row(q); - const float* tmp_hidden_ptr = tmp_hidden_state; - - float H = 0; - for (int i = 0; i < hidden_size; i++) - { - H += tmp_hidden_ptr[i] * hr[i]; - } - - hidden_ptr[q] = H; - output_data[q] = (__fp16)H; - } - } - } - - return 0; -} - static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) { int size = bottom_blob.w; @@ -643,29 +423,252 @@ static int lstm_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const return 0; } -int LSTM_arm::create_pipeline_fp16s(const Option& opt) +static int lstm_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) { - // pack IFOG - int num_directions = direction == 2 ? 
2 : 1; - int size = weight_data_size / num_directions / hidden_size / 4; - if (opt.use_fp16_arithmetic) + return lstm_fp16sa(bottom_blob, top_blob, reverse, weight_xc, bias_c, weight_hc, weight_hr, hidden_state, cell_state, opt); + + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + Mat tmp_hidden_state; + if (num_output != hidden_size) { - weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 16u, 8); - bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); - weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 16u, 8); - } - else - { - weight_xc_data_packed.create(size, hidden_size, num_directions, 8u, 4); - bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); - weight_hc_data_packed.create(num_output, hidden_size, num_directions, 8u, 4); + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; } - #pragma omp parallel for num_threads(opt.num_threads) - for (int dr = 0; dr < num_directions; dr++) + // unroll + for (int t = 0; t < T; t++) { - const Mat weight_xc = weight_xc_data.channel(dr); + // clip hidden by continuation indicator + // h_cont_{t-1} = cont_t * h_{t-1} + // h_cont_{t-1} = h_{t-1} if cont_t == 1 + // 0 otherwise + // calculate hidden + // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c + + int ti = reverse ? T - 1 - t : t; + + const __fp16* x = bottom_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const __fp16* bias_c_IFOG = (const __fp16*)bias_c + q * 4; + + // gate I F O G + const __fp16* weight_xc_IFOG = weight_xc.row(q); + + const __fp16* weight_hc_IFOG = weight_hc.row(q); + + float32x4_t _IFOG = vcvt_f32_f16(vld1_f16(bias_c_IFOG)); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + + int i = 0; + for (; i + 3 < size; i += 4) + { + float32x4_t _xi = vcvt_f32_f16(vld1_f16(x + i)); + + float32x4_t _weight_xc_IFOG_0 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG)); + float32x4_t _weight_xc_IFOG_1 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 4)); + float32x4_t _weight_xc_IFOG_2 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 8)); + float32x4_t _weight_xc_IFOG_3 = vcvt_f32_f16(vld1_f16(weight_xc_IFOG + 12)); + + _IFOG = vfmaq_laneq_f32(_IFOG, _weight_xc_IFOG_0, _xi, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_IFOG_1, _xi, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_IFOG_2, _xi, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_IFOG_3, _xi, 3); + + weight_xc_IFOG += 16; + } + for (; i < size; i++) + { + __fp16 xi = x[i]; + + float32x4_t _xi = vcvt_f32_f16(vdup_n_f16(xi)); + float32x4_t _weight_xc_IFOG = vcvt_f32_f16(vld1_f16(weight_xc_IFOG)); + _IFOG = vfmaq_f32(_IFOG, _weight_xc_IFOG, _xi); + + weight_xc_IFOG += 4; + } + + i = 0; + for (; i + 3 < num_output; i += 4) + { + float32x4_t _h_cont = vld1q_f32((const float*)hidden_state + i); + + float32x4_t _weight_hc_IFOG_0 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG)); + float32x4_t _weight_hc_IFOG_1 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 4)); + float32x4_t _weight_hc_IFOG_2 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 8)); + float32x4_t _weight_hc_IFOG_3 = vcvt_f32_f16(vld1_f16(weight_hc_IFOG + 12)); + + _IFOG = vfmaq_laneq_f32(_IFOG, _weight_hc_IFOG_0, _h_cont, 
0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_IFOG_1, _h_cont, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_IFOG_2, _h_cont, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_IFOG_3, _h_cont, 3); + + weight_hc_IFOG += 16; + } + for (; i < num_output; i++) + { + float h_cont = hidden_state[i]; + + float32x4_t _h_cont = vdupq_n_f32(h_cont); + float32x4_t _weight_hc_IFOG = vcvt_f32_f16(vld1_f16(weight_hc_IFOG)); + _IFOG = vfmaq_f32(_IFOG, _weight_hc_IFOG, _h_cont); + + weight_hc_IFOG += 4; + } + + float* gates_data = gates.row(q); + + _IFOG = vaddq_f32(_IFOG, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + _IFOG = vaddq_f32(_IFOG, _sum2); + + vst1q_f32(gates_data, _IFOG); + } + + // lstm unit + // sigmoid(I) + // sigmoid(F) + // sigmoid(O) + // tanh(G) + // c_t := f_t .* c_{t-1} + i_t .* g_t + // h_t := o_t .* tanh[c_t] + __fp16* output_data = top_blob.row<__fp16>(ti); + + float* cell_ptr = cell_state; + float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; + + int nn_hidden_size = hidden_size >> 2; + int remain_hidden_size_start = nn_hidden_size << 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = qq * 4; + + const float* gates_data = gates.row(q); + + float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data); + + float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]); + float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]); + float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]); + float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]); + + float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G)); + float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2)); + + vst1q_f32(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + vst1q_f32(hidden_ptr + q, _lstm_H); + vst1_f16(output_data + q, vcvt_f16_f32(_lstm_H)); + } + else + { + vst1q_f32(tmp_hidden_ptr + q, _lstm_H); + } + } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_ptr[q] + I * G; + float H = O * tanhf(cell2); + + cell_ptr[q] = cell2; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } + } + } + + return 0; +} + +int LSTM_arm::create_pipeline_fp16s(const Option& opt) +{ + // pack IFOG + const int num_directions = direction == 2 ? 
2 : 1; + const int size = weight_data_size / num_directions / hidden_size / 4; + + if (opt.use_fp16_arithmetic) + { + weight_xc_data_packed.create(size, hidden_size / 2 + hidden_size % 2, num_directions, 16u, 8); + bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); + weight_hc_data_packed.create(num_output, hidden_size / 2 + hidden_size % 2, num_directions, 16u, 8); + } + else + { + weight_xc_data_packed.create(size, hidden_size, num_directions, 8u, 4); + bias_c_data_packed.create(hidden_size, 1, num_directions, 8u, 4); + weight_hc_data_packed.create(num_output, hidden_size, num_directions, 8u, 4); + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int dr = 0; dr < num_directions; dr++) + { + const Mat weight_xc = weight_xc_data.channel(dr); const Mat bias_c = bias_c_data.channel(dr); const Mat weight_hc = weight_hc_data.channel(dr); @@ -680,9 +683,9 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) __fp16* bias_c_IFOG = bias_c_data_packed_dr.row<__fp16>(0); + int q = 0; if (opt.use_fp16_arithmetic) { - int q = 0; for (; q + 1 < hidden_size; q += 2) { bias_c_IFOG[0] = (__fp16)bias_c_I[q]; @@ -745,92 +748,48 @@ int LSTM_arm::create_pipeline_fp16s(const Option& opt) weight_hc_IFOG += 8; } } - for (; q < hidden_size; q++) - { - bias_c_IFOG[0] = (__fp16)bias_c_I[q]; - bias_c_IFOG[1] = (__fp16)bias_c_F[q]; - bias_c_IFOG[2] = (__fp16)bias_c_O[q]; - bias_c_IFOG[3] = (__fp16)bias_c_G[q]; - - bias_c_IFOG += 4; - - const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); - const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); - const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); - const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); - - const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); - const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); - const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); - const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); - - __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q / 2 + q % 2); - __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q / 2 + q % 2); - - for (int i = 0; i < size; i++) - { - weight_xc_IFOG[0] = (__fp16)weight_xc_I[i]; - weight_xc_IFOG[1] = (__fp16)weight_xc_F[i]; - weight_xc_IFOG[2] = (__fp16)weight_xc_O[i]; - weight_xc_IFOG[3] = (__fp16)weight_xc_G[i]; - - weight_xc_IFOG += 4; - } - - for (int i = 0; i < num_output; i++) - { - weight_hc_IFOG[0] = (__fp16)weight_hc_I[i]; - weight_hc_IFOG[1] = (__fp16)weight_hc_F[i]; - weight_hc_IFOG[2] = (__fp16)weight_hc_O[i]; - weight_hc_IFOG[3] = (__fp16)weight_hc_G[i]; - - weight_hc_IFOG += 4; - } - } } - else + for (; q < hidden_size; q++) { - for (int q = 0; q < hidden_size; q++) - { - bias_c_IFOG[0] = (__fp16)bias_c_I[q]; - bias_c_IFOG[1] = (__fp16)bias_c_F[q]; - bias_c_IFOG[2] = (__fp16)bias_c_O[q]; - bias_c_IFOG[3] = (__fp16)bias_c_G[q]; + bias_c_IFOG[0] = (__fp16)bias_c_I[q]; + bias_c_IFOG[1] = (__fp16)bias_c_F[q]; + bias_c_IFOG[2] = (__fp16)bias_c_O[q]; + bias_c_IFOG[3] = (__fp16)bias_c_G[q]; - bias_c_IFOG += 4; + bias_c_IFOG += 4; - const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); - const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); - const float* weight_xc_O = weight_xc.row(hidden_size * 2 + q); - const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); + const float* weight_xc_I = weight_xc.row(hidden_size * 0 + q); + const float* weight_xc_F = weight_xc.row(hidden_size * 1 + q); + const float* weight_xc_O = 
weight_xc.row(hidden_size * 2 + q); + const float* weight_xc_G = weight_xc.row(hidden_size * 3 + q); - const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); - const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); - const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); - const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); + const float* weight_hc_I = weight_hc.row(hidden_size * 0 + q); + const float* weight_hc_F = weight_hc.row(hidden_size * 1 + q); + const float* weight_hc_O = weight_hc.row(hidden_size * 2 + q); + const float* weight_hc_G = weight_hc.row(hidden_size * 3 + q); - __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(q); - __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(q); + const int qq = opt.use_fp16_arithmetic ? q / 2 + q % 2 : q; + __fp16* weight_xc_IFOG = weight_xc_data_packed_dr.row<__fp16>(qq); + __fp16* weight_hc_IFOG = weight_hc_data_packed_dr.row<__fp16>(qq); - for (int i = 0; i < size; i++) - { - weight_xc_IFOG[0] = (__fp16)weight_xc_I[i]; - weight_xc_IFOG[1] = (__fp16)weight_xc_F[i]; - weight_xc_IFOG[2] = (__fp16)weight_xc_O[i]; - weight_xc_IFOG[3] = (__fp16)weight_xc_G[i]; + for (int i = 0; i < size; i++) + { + weight_xc_IFOG[0] = (__fp16)weight_xc_I[i]; + weight_xc_IFOG[1] = (__fp16)weight_xc_F[i]; + weight_xc_IFOG[2] = (__fp16)weight_xc_O[i]; + weight_xc_IFOG[3] = (__fp16)weight_xc_G[i]; - weight_xc_IFOG += 4; - } + weight_xc_IFOG += 4; + } - for (int i = 0; i < num_output; i++) - { - weight_hc_IFOG[0] = (__fp16)weight_hc_I[i]; - weight_hc_IFOG[1] = (__fp16)weight_hc_F[i]; - weight_hc_IFOG[2] = (__fp16)weight_hc_O[i]; - weight_hc_IFOG[3] = (__fp16)weight_hc_G[i]; + for (int i = 0; i < num_output; i++) + { + weight_hc_IFOG[0] = (__fp16)weight_hc_I[i]; + weight_hc_IFOG[1] = (__fp16)weight_hc_F[i]; + weight_hc_IFOG[2] = (__fp16)weight_hc_O[i]; + weight_hc_IFOG[3] = (__fp16)weight_hc_G[i]; - weight_hc_IFOG += 4; - } + weight_hc_IFOG += 4; } } } @@ -884,166 +843,20 @@ int LSTM_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret0 != 0) - return ret0; - - hidden.fill(0.f); - cell.fill(0.f); - - int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret1 != 0) - return ret1; - - // concat w - for (int i = 0; i < T; i++) { - const __fp16* pf = top_blob_forward.row(i); - const __fp16* pr = top_blob_reverse.row(i); - __fp16* ptr = top_blob.row<__fp16>(i); - - memcpy(ptr, pf, num_output * sizeof(__fp16)); - memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + int ret = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; } - } - - return 0; -} - -int LSTM_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const -{ - const Mat& bottom_blob = bottom_blobs[0]; - int T = bottom_blob.h; - int num_directions = direction == 2 ? 
2 : 1; - Mat hidden; - Mat cell; - Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? opt.blob_allocator : opt.workspace_allocator; - if (bottom_blobs.size() == 3) - { - Option opt_cast = opt; - opt_cast.blob_allocator = hidden_cell_allocator; - cast_float16_to_float32(bottom_blobs[1], hidden, opt_cast); - cast_float16_to_float32(bottom_blobs[2], cell, opt_cast); - } - else - { - hidden.create(num_output, num_directions, 4u, hidden_cell_allocator); - if (hidden.empty()) - return -100; hidden.fill(0.f); - - cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); - if (cell.empty()) - return -100; cell.fill(0.f); - } - Mat& top_blob = top_blobs[0]; - top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator); - if (top_blob.empty()) - return -100; - - // Uni directional - if (direction == 0 || direction == 1) - { - int ret = lstm_fp16s(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; - } - - if (direction == 2) - { - Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_forward.empty()) - return -100; - - Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_reverse.empty()) - return -100; - - Mat hidden0 = hidden.row_range(0, 1); - Mat cell0 = cell.row_range(0, 1); - int ret0 = lstm_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); - if (ret0 != 0) - return ret0; - - Mat hidden1 = hidden.row_range(1, 1); - Mat cell1 = cell.row_range(1, 1); - int ret1 = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); - if (ret1 != 0) - return ret1; - - // concat w - for (int i = 0; i < T; i++) { - const __fp16* pf = top_blob_forward.row(i); - const __fp16* pr = top_blob_reverse.row(i); - __fp16* ptr = top_blob.row<__fp16>(i); - - memcpy(ptr, pf, num_output * sizeof(__fp16)); - memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + int ret = lstm_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; } - } - - if (top_blobs.size() == 3) - { - cast_float32_to_float16(hidden, top_blobs[1], opt); - cast_float32_to_float16(cell, top_blobs[2], opt); - } - - return 0; -} - -int LSTM_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const -{ - int T = bottom_blob.h; - - int num_directions = direction == 2 ? 
2 : 1; - - // initial hidden state - Mat hidden(num_output, 4u, opt.workspace_allocator); - if (hidden.empty()) - return -100; - hidden.fill(0.f); - - Mat cell(hidden_size, 4u, opt.workspace_allocator); - if (cell.empty()) - return -100; - cell.fill(0.f); - - top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator); - if (top_blob.empty()) - return -100; - - // Uni directional - if (direction == 0 || direction == 1) - { - int ret = lstm_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; - } - - if (direction == 2) - { - Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_forward.empty()) - return -100; - - Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_reverse.empty()) - return -100; - - int ret0 = lstm_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret0 != 0) - return ret0; - - hidden.fill(0.f); - cell.fill(0.f); - - int ret1 = lstm_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret1 != 0) - return ret1; // concat w for (int i = 0; i < T; i++) @@ -1060,7 +873,7 @@ int LSTM_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option return 0; } -int LSTM_arm::forward_fp16sa(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +int LSTM_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { const Mat& bottom_blob = bottom_blobs[0]; int T = bottom_blob.h; @@ -1097,7 +910,7 @@ int LSTM_arm::forward_fp16sa(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(hidden_size * 0 + q); + const signed char* weight_xc_F = weight_xc_dr.row(hidden_size * 1 + q); + const signed char* weight_xc_O = weight_xc_dr.row(hidden_size * 2 + q); + const signed char* weight_xc_G = weight_xc_dr.row(hidden_size * 3 + q); + + const signed char* weight_hc_I = weight_hc_dr.row(hidden_size * 0 + q); + const signed char* weight_hc_F = weight_hc_dr.row(hidden_size * 1 + q); + const signed char* weight_hc_O = weight_hc_dr.row(hidden_size * 2 + q); + const signed char* weight_hc_G = weight_hc_dr.row(hidden_size * 3 + q); + + signed char* kptr = weight_data_tm_dr.row(q); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q); + + int i = 0; +#if __ARM_NEON +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_I[i]; + kptr[1] = weight_xc_I[i + 1]; + kptr[2] = weight_xc_I[i + 2]; + kptr[3] = weight_xc_I[i + 3]; + kptr[4] = weight_xc_F[i]; + kptr[5] = weight_xc_F[i + 1]; + kptr[6] = weight_xc_F[i + 2]; + kptr[7] = weight_xc_F[i + 3]; + kptr[8 + 0] = weight_xc_O[i]; + kptr[8 + 1] = weight_xc_O[i + 1]; + kptr[8 + 2] = weight_xc_O[i + 2]; + kptr[8 + 3] = weight_xc_O[i + 3]; + kptr[8 + 4] = weight_xc_G[i]; + kptr[8 + 5] = weight_xc_G[i + 1]; + kptr[8 + 6] = weight_xc_G[i + 2]; + kptr[8 + 7] = weight_xc_G[i + 3]; + kptr += 16; + } +#else + for (; i + 7 < size; i += 8) + { + vst1_s8(kptr, vld1_s8(weight_xc_I + i)); + vst1_s8(kptr + 8, vld1_s8(weight_xc_F + 
i)); + vst1_s8(kptr + 16, vld1_s8(weight_xc_O + i)); + vst1_s8(kptr + 24, vld1_s8(weight_xc_G + i)); + kptr += 32; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < size; i += 2) + { + kptr[0] = weight_xc_I[i]; + kptr[1] = weight_xc_I[i + 1]; + kptr[2] = weight_xc_F[i]; + kptr[3] = weight_xc_F[i + 1]; + kptr[4] = weight_xc_O[i]; + kptr[5] = weight_xc_O[i + 1]; + kptr[6] = weight_xc_G[i]; + kptr[7] = weight_xc_G[i + 1]; + kptr += 8; + } +#endif // __ARM_NEON + for (; i < size; i++) + { + kptr[0] = weight_xc_I[i]; + kptr[1] = weight_xc_F[i]; + kptr[2] = weight_xc_O[i]; + kptr[3] = weight_xc_G[i]; + kptr += 4; + } + + i = 0; +#if __ARM_NEON +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_I[i + 1]; + kptr[2] = weight_hc_I[i + 2]; + kptr[3] = weight_hc_I[i + 3]; + kptr[4] = weight_hc_F[i]; + kptr[5] = weight_hc_F[i + 1]; + kptr[6] = weight_hc_F[i + 2]; + kptr[7] = weight_hc_F[i + 3]; + kptr[8 + 0] = weight_hc_O[i]; + kptr[8 + 1] = weight_hc_O[i + 1]; + kptr[8 + 2] = weight_hc_O[i + 2]; + kptr[8 + 3] = weight_hc_O[i + 3]; + kptr[8 + 4] = weight_hc_G[i]; + kptr[8 + 5] = weight_hc_G[i + 1]; + kptr[8 + 6] = weight_hc_G[i + 2]; + kptr[8 + 7] = weight_hc_G[i + 3]; + kptr += 16; + } +#else + for (; i + 7 < num_output; i += 8) + { + vst1_s8(kptr, vld1_s8(weight_hc_I + i)); + vst1_s8(kptr + 8, vld1_s8(weight_hc_F + i)); + vst1_s8(kptr + 16, vld1_s8(weight_hc_O + i)); + vst1_s8(kptr + 24, vld1_s8(weight_hc_G + i)); + kptr += 32; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < num_output; i += 2) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_I[i + 1]; + kptr[2] = weight_hc_F[i]; + kptr[3] = weight_hc_F[i + 1]; + kptr[4] = weight_hc_O[i]; + kptr[5] = weight_hc_O[i + 1]; + kptr[6] = weight_hc_G[i]; + kptr[7] = weight_hc_G[i + 1]; + kptr += 8; + } +#endif // __ARM_NEON + for (; i < num_output; i++) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_F[i]; + kptr[2] = weight_hc_O[i]; + kptr[3] = weight_hc_G[i]; + kptr += 4; + } + + descales_ptr[0] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 0 + q]; + descales_ptr[1] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 1 + q]; + descales_ptr[2] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 2 + q]; + descales_ptr[3] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 3 + q]; + descales_ptr[4] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 0 + q]; + descales_ptr[5] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 1 + q]; + descales_ptr[6] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 2 + q]; + descales_ptr[7] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 3 + q]; + } + } +} + +static void lstm_int8_gate_output(const Mat& gates, const Mat& weight_hr, Mat& hidden_state, Mat& tmp_hidden_state, Mat& cell_state, Mat& top_blob, int ti, int elemtype, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2) + if (ncnn::cpu_support_arm_vfpv4()) + { + lstm_int8_gate_output_vfpv4(gates, weight_hr, hidden_state, tmp_hidden_state, cell_state, top_blob, ti, elemtype, opt); + return; + } +#endif + + const int num_output = top_blob.w; + const int hidden_size = cell_state.w; + + // lstm unit + // sigmoid(I) + // sigmoid(F) + // sigmoid(O) + // tanh(G) + // c_t := f_t .* c_{t-1} + i_t .* g_t + // h_t := o_t .* tanh[c_t] + float* output_data = top_blob.row(ti); + + float* cell_ptr = cell_state; + float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; + + int remain_hidden_size_start = 0; +#if __ARM_NEON + int nn_hidden_size = 
hidden_size >> 2; + remain_hidden_size_start = nn_hidden_size << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = qq * 4; + + const float* gates_data = gates.row(q); + + float32x4x4_t _IFOG_4x4 = vld4q_f32(gates_data); + + float32x4_t _lstm_I = sigmoid_ps(_IFOG_4x4.val[0]); + float32x4_t _lstm_F = sigmoid_ps(_IFOG_4x4.val[1]); + float32x4_t _lstm_O = sigmoid_ps(_IFOG_4x4.val[2]); + float32x4_t _lstm_G = tanh_ps(_IFOG_4x4.val[3]); + + float32x4_t _cell2 = vaddq_f32(vmulq_f32(_lstm_F, vld1q_f32(cell_ptr + q)), vmulq_f32(_lstm_I, _lstm_G)); + float32x4_t _lstm_H = vmulq_f32(_lstm_O, tanh_ps(_cell2)); + + vst1q_f32(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + vst1q_f32(hidden_ptr + q, _lstm_H); + + if (elemtype == 1) + { + // fp32 + vst1q_f32(output_data + q, _lstm_H); + } + if (elemtype == 2) + { + // fp16 + unsigned short* outptr = (unsigned short*)output_data + q; +#if (__ARM_FP & 2) +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "fcvtn v0.4h, %2.4s \n" + "st1 {v0.4h}, [%0] \n" + : "=r"(outptr) // %0 + : "0"(outptr), + "w"(_lstm_H) + : "memory", "v0"); +#else // __aarch64__ + asm volatile( + "vcvt.f16.f32 d0, %q2 \n" + "vst1.u16 {d0}, [%0] \n" + : "=r"(outptr) // %0 + : "0"(outptr), + "w"(_lstm_H) + : "memory", "q0"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + vst1_u16(outptr, (uint16x4_t)vcvt_f16_f32(_lstm_H)); +#endif // NCNN_GNU_INLINE_ASM +#else + outptr[q] = float32_to_float16(hidden_ptr[q]); + outptr[q + 1] = float32_to_float16(hidden_ptr[q + 1]); + outptr[q + 2] = float32_to_float16(hidden_ptr[q + 2]); + outptr[q + 3] = float32_to_float16(hidden_ptr[q + 3]); +#endif // (__ARM_FP & 2) + } + if (elemtype == 4) + { + // bf16 + vst1_u16((unsigned short*)output_data + q, float2bfloat(_lstm_H)); + } + } + else + { + vst1q_f32(tmp_hidden_ptr + q, _lstm_H); + } + } +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_ptr[q] + I * G; + float H = O * tanhf(cell2); + + cell_ptr[q] = cell2; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + + if (elemtype == 1) + { + output_data[q] = H; + } + if (elemtype == 2) + { + ((unsigned short*)output_data)[q] = float32_to_float16(H); + } + if (elemtype == 4) + { + ((unsigned short*)output_data)[q] = float32_to_bfloat16(H); + } + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + + if (elemtype == 1) + { + output_data[q] = H; + } + if (elemtype == 2) + { + ((unsigned short*)output_data)[q] = float32_to_float16(H); + } + if 
(elemtype == 4) + { + ((unsigned short*)output_data)[q] = float32_to_bfloat16(H); + } + } + } +} + +static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD + if (ncnn::cpu_support_arm_asimddp()) + { + lstm_int8_asimddp(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); + return; + } +#endif + + int size = bottom_blob_int8.w; + int T = bottom_blob_int8.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); + + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + } + + Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator); + float hidden_state_int8_scale = 1.f; + float hidden_state_int8_descale = 1.f; + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + // dynamic quantize hidden_state + { + float absmax = 0.f; + for (int i = 0; i < num_output; i++) + { + absmax = std::max(absmax, (float)fabs(hidden_state[i])); + } + + if (absmax == 0.f) + { + hidden_state_int8.fill(0); + } + else + { + hidden_state_int8_scale = 127.f / absmax; + hidden_state_int8_descale = absmax / 127.f; + + signed char* hs = hidden_state_int8; + for (int i = 0; i < num_output; i++) + { + hs[i] = float2int8(hidden_state[i] * hidden_state_int8_scale); + } + } + } + + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = bottom_blob_int8_descales[ti]; + const float descale_h = hidden_state_int8_descale; + + // gate reset update + const float* bias_c_IFOG = (const float*)bias_c + q * 4; + + const signed char* kptr = weight_data_tm.row(q); + const float* descales_ptr = weight_data_tm_int8_descales.row(q); + + float* gates_data = gates.row(q); + +#if __ARM_NEON + int32x4_t _lstm_IFOGx0 = vdupq_n_s32(0); + int i = 0; +#if __ARM_FEATURE_DOTPROD + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + for (; i + 15 < size; i += 16) + { + int8x16_t _xi = vld1q_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + _lstm_IFOGx0 = vdotq_laneq_s32(_lstm_IFOGx0, _w0, _xi, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w1, _xi, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w2, _xi, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w3, _xi, 3); + + kptr += 64; + } + for (; i + 7 < size; i += 8) + { + int8x8_t _xi = vld1_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _lstm_IFOGx0 = vdotq_lane_s32(_lstm_IFOGx0, _w0, _xi, 0); + _sum1 = vdotq_lane_s32(_sum1, _w1, _xi, 1); + + kptr += 32; + } + _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum1); + _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum2); + _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum3); +#else + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + 
int32x4_t _sum3 = vdupq_n_s32(0); + for (; i + 15 < size; i += 16) + { +#if NCNN_GNU_INLINE_ASM && !__aarch64__ + const signed char* xptr = x + i; + + asm volatile( + "vldm %1!, {d0-d7} \n" + "vld1.s8 {d16-d17}, [%0] \n" + "vmull.s8 q4, d0, d16 \n" + "vmull.s8 q5, d1, d16 \n" + "vmull.s8 q6, d2, d16 \n" + "vmull.s8 q7, d3, d16 \n" + "vmlal.s8 q4, d4, d17 \n" + "vmlal.s8 q5, d5, d17 \n" + "vmlal.s8 q6, d6, d17 \n" + "vmlal.s8 q7, d7, d17 \n" + "vpadal.s16 %q2, q4 \n" + "vpadal.s16 %q3, q5 \n" + "vpadal.s16 %q4, q6 \n" + "vpadal.s16 %q5, q7 \n" + : "=r"(xptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3) + : "0"(xptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"); +#else + int8x16_t _xi = vld1q_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_xi)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_xi)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), vget_low_s8(_xi)); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_xi)); + _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_xi)); + _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_xi)); + _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_xi)); + _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_xi)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 64; +#endif + } + for (; i + 7 < size; i += 8) + { + int8x8_t _xi = vld1_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _xi); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _xi); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _xi); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _xi); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 32; + } + { + int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1])); + } + _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum0); + _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum1); + _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum2); + _lstm_IFOGx0 = vaddq_s32(_lstm_IFOGx0, _sum3); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _xi = vld1_s8(x + i); + int8x16_t _w = vld1q_s8(kptr); + _lstm_IFOGx0 = vdotq_lane_s32(_lstm_IFOGx0, _w, _xi, 0); +#else + int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i)); + int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0)); + int8x8_t _xi1 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 1)); + int8x16_t _w01 = vld1q_s8(kptr); + + int16x8_t _lstm_IFOGx = vmull_s8(vget_low_s8(_w01), _xi0); + _lstm_IFOGx = vmlal_s8(_lstm_IFOGx, vget_high_s8(_w01), _xi1); + _lstm_IFOGx0 = vpadalq_s16(_lstm_IFOGx0, _lstm_IFOGx); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 16; + } + for (; i + 1 < size; i += 2) + { + int8x8_t _xi = 
vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(x + i)), 0)); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _lstm_IFOGx = vmull_s8(_w, _xi); + _lstm_IFOGx0 = vpadalq_s16(_lstm_IFOGx0, _lstm_IFOGx); + + kptr += 8; + } + for (; i < size; i++) + { + int8x8_t _xi = vdup_n_s8(x[i]); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _lstm_IFOGx = vmull_s8(_w, _xi); + _lstm_IFOGx0 = vaddw_s16(_lstm_IFOGx0, vget_low_s16(_lstm_IFOGx)); + + kptr += 4; + } + + int32x4_t _lstm_IFOGh0 = vdupq_n_s32(0); + i = 0; +#if __ARM_FEATURE_DOTPROD + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + for (; i + 15 < num_output; i += 16) + { + int8x16_t _h_cont = vld1q_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + _lstm_IFOGh0 = vdotq_laneq_s32(_lstm_IFOGh0, _w0, _h_cont, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w1, _h_cont, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w2, _h_cont, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w3, _h_cont, 3); + + kptr += 64; + } + for (; i + 7 < num_output; i += 8) + { + int8x8_t _h_cont = vld1_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _lstm_IFOGh0 = vdotq_lane_s32(_lstm_IFOGh0, _w0, _h_cont, 0); + _sum1 = vdotq_lane_s32(_sum1, _w1, _h_cont, 1); + + kptr += 32; + } + _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum1); + _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum2); + _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum3); +#else + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + for (; i + 15 < num_output; i += 16) + { +#if NCNN_GNU_INLINE_ASM && !__aarch64__ + const signed char* hsptr = hs + i; + + asm volatile( + "vldm %1!, {d0-d7} \n" + "vld1.s8 {d16-d17}, [%0] \n" + "vmull.s8 q4, d0, d16 \n" + "vmull.s8 q5, d1, d16 \n" + "vmull.s8 q6, d2, d16 \n" + "vmull.s8 q7, d3, d16 \n" + "vmlal.s8 q4, d4, d17 \n" + "vmlal.s8 q5, d5, d17 \n" + "vmlal.s8 q6, d6, d17 \n" + "vmlal.s8 q7, d7, d17 \n" + "vpadal.s16 %q2, q4 \n" + "vpadal.s16 %q3, q5 \n" + "vpadal.s16 %q4, q6 \n" + "vpadal.s16 %q5, q7 \n" + : "=r"(hsptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3) + : "0"(hsptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"); +#else + int8x16_t _h_cont = vld1q_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_h_cont)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_h_cont)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), vget_low_s8(_h_cont)); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_h_cont)); + _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_h_cont)); + _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_h_cont)); + _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_h_cont)); + _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_h_cont)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 64; +#endif + } + for (; i + 7 < num_output; i += 8) + { + int8x8_t _h_cont = vld1_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _h_cont); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _h_cont); + 
int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _h_cont); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _h_cont); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 32; + } + { + int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1])); + } + _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum0); + _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum1); + _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum2); + _lstm_IFOGh0 = vaddq_s32(_lstm_IFOGh0, _sum3); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _h_cont = vld1_s8(hs + i); + int8x16_t _w = vld1q_s8(kptr); + _lstm_IFOGh0 = vdotq_lane_s32(_lstm_IFOGh0, _w, _h_cont, 0); +#else + int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i)); + int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0)); + int8x8_t _h_cont1 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 1)); + int8x16_t _w01 = vld1q_s8(kptr); + + int16x8_t _lstm_IFOGh = vmull_s8(vget_low_s8(_w01), _h_cont0); + _lstm_IFOGh = vmlal_s8(_lstm_IFOGh, vget_high_s8(_w01), _h_cont1); + _lstm_IFOGh0 = vpadalq_s16(_lstm_IFOGh0, _lstm_IFOGh); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 16; + } + for (; i + 1 < num_output; i += 2) + { + int8x8_t _h_cont = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(hs + i)), 0)); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _lstm_IFOGh = vmull_s8(_w, _h_cont); + _lstm_IFOGh0 = vpadalq_s16(_lstm_IFOGh0, _lstm_IFOGh); + + kptr += 8; + } + for (; i < num_output; i++) + { + int8x8_t _h_cont = vdup_n_s8(hs[i]); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _lstm_IFOGh = vmull_s8(_w, _h_cont); + _lstm_IFOGh0 = vaddw_s16(_lstm_IFOGh0, vget_low_s16(_lstm_IFOGh)); + + kptr += 4; + } + + float32x4_t _descale_x = vdupq_n_f32(descale_x); + float32x4_t _descale_h = vdupq_n_f32(descale_h); + + float32x4_t _lstm_IFOG0 = vld1q_f32(bias_c_IFOG); + + float32x4_t _descale_xc_IFOG = vld1q_f32(descales_ptr); + + _lstm_IFOG0 = vmlaq_f32(_lstm_IFOG0, vcvtq_f32_s32(_lstm_IFOGx0), vmulq_f32(_descale_x, _descale_xc_IFOG)); + + float32x4_t _descale_hc_IFOG = vld1q_f32(descales_ptr + 4); + + _lstm_IFOG0 = vmlaq_f32(_lstm_IFOG0, vcvtq_f32_s32(_lstm_IFOGh0), vmulq_f32(_descale_h, _descale_hc_IFOG)); + + vst1q_f32(gates_data, _lstm_IFOG0); +#else + int Ix = 0; + int Fx = 0; + int Ox = 0; + int Gx = 0; + for (int i = 0; i < size; i++) + { + signed char xi = x[i]; + + Ix += kptr[0] * xi; + Fx += kptr[1] * xi; + Ox += kptr[2] * xi; + Gx += kptr[3] * xi; + + kptr += 4; + } + + int Ih = 0; + int Fh = 0; + int Oh = 0; + int Gh = 0; + for (int i = 0; i < num_output; i++) + { + signed char h_cont = hs[i]; + + Ih += kptr[0] * h_cont; + Fh += kptr[1] * h_cont; + Oh += kptr[2] * h_cont; + Gh += kptr[3] * h_cont; + + kptr += 4; + } + + const float descale_xc_I = descales_ptr[0]; + const float descale_xc_F = descales_ptr[1]; + const float descale_xc_O = descales_ptr[2]; + const float descale_xc_G = descales_ptr[3]; + const float descale_hc_I = descales_ptr[4]; + const float descale_hc_F = descales_ptr[5]; + const float descale_hc_O = descales_ptr[6]; + const float descale_hc_G = 
descales_ptr[7]; + + float I = bias_c_IFOG[0] + Ix * (descale_x * descale_xc_I) + Ih * (descale_h * descale_hc_I); + float F = bias_c_IFOG[1] + Fx * (descale_x * descale_xc_F) + Fh * (descale_h * descale_hc_F); + float O = bias_c_IFOG[2] + Ox * (descale_x * descale_xc_O) + Oh * (descale_h * descale_hc_O); + float G = bias_c_IFOG[3] + Gx * (descale_x * descale_xc_G) + Gh * (descale_h * descale_hc_G); + + gates_data[0] = I; + gates_data[1] = F; + gates_data[2] = O; + gates_data[3] = G; +#endif // __ARM_NEON + } + + lstm_int8_gate_output(gates, weight_hr, hidden_state, tmp_hidden_state, cell_state, top_blob, ti, elemtype, opt); + } +} diff --git a/src/layer/arm/neon_mathfun.h b/src/layer/arm/neon_mathfun.h index 2b4094a9ed7..537f8c1b641 100644 --- a/src/layer/arm/neon_mathfun.h +++ b/src/layer/arm/neon_mathfun.h @@ -276,7 +276,7 @@ static inline float32x4_t div_ps(float32x4_t a, float32x4_t b) #else float32x4_t reciprocal = vrecpeq_f32(b); reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal); - // reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal); + reciprocal = vmulq_f32(vrecpsq_f32(b, reciprocal), reciprocal); return vmulq_f32(a, reciprocal); #endif } @@ -302,7 +302,7 @@ static inline float32x4_t sigmoid_ps(float32x4_t _v) _v = exp_ps(_v); _v = vaddq_f32(_v, _one); float32x4_t _outp = vrecpeq_f32(_v); - // _outp = vmulq_f32(vrecpsq_f32(_v, _outp), _outp); + _outp = vmulq_f32(vrecpsq_f32(_v, _outp), _outp); return vmulq_f32(vrecpsq_f32(_v, _outp), _outp); } diff --git a/src/layer/arm/rnn_arm.cpp b/src/layer/arm/rnn_arm.cpp index 293322b8488..8177448a32e 100644 --- a/src/layer/arm/rnn_arm.cpp +++ b/src/layer/arm/rnn_arm.cpp @@ -25,6 +25,10 @@ namespace ncnn { +#if NCNN_INT8 +#include "rnn_int8.h" +#endif + RNN_arm::RNN_arm() { #if __ARM_NEON @@ -40,6 +44,13 @@ RNN_arm::RNN_arm() int RNN_arm::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + return create_pipeline_int8(opt); + } +#endif + #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage) { @@ -54,12 +65,16 @@ int RNN_arm::create_pipeline(const Option& opt) } #endif - int num_directions = direction == 2 ? 2 : 1; - int size = weight_data_size / num_directions / num_output; + const int num_directions = direction == 2 ? 
2 : 1; + const int size = weight_data_size / num_directions / num_output; #if __ARM_NEON weight_xc_data_packed.create(size * 4, num_output / 4 + num_output % 4, num_directions); weight_hc_data_packed.create(num_output * 4, num_output / 4 + num_output % 4, num_directions); +#else + weight_xc_data_packed.create(size, num_output, num_directions); + weight_hc_data_packed.create(num_output, num_output, num_directions); +#endif #pragma omp parallel for num_threads(opt.num_threads) for (int dr = 0; dr < num_directions; dr++) @@ -132,10 +147,6 @@ int RNN_arm::create_pipeline(const Option& opt) } } } -#else - weight_xc_data_packed = weight_xc_data; - weight_hc_data_packed = weight_hc_data; -#endif bias_c_data_packed = bias_c_data; @@ -319,16 +330,18 @@ static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we int RNN_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blob, top_blob, opt); + } +#endif + int elembits = bottom_blob.elembits(); #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - { - if (opt.use_fp16_arithmetic) - return forward_fp16sa(bottom_blob, top_blob, opt); - else - return forward_fp16s(bottom_blob, top_blob, opt); - } + return forward_fp16s(bottom_blob, top_blob, opt); #endif #if NCNN_BF16 @@ -368,15 +381,19 @@ int RNN_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c if (top_blob_reverse.empty()) return -100; - int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; + { + int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.0f); - int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; + { + int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -395,17 +412,19 @@ int RNN_arm::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) c int RNN_arm::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blobs, top_blobs, opt); + } +#endif + const Mat& bottom_blob = bottom_blobs[0]; int elembits = bottom_blob.elembits(); #if NCNN_ARM82 if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) - { - if (opt.use_fp16_arithmetic) - return forward_fp16sa(bottom_blobs, top_blobs, opt); - else - return forward_fp16s(bottom_blobs, top_blobs, opt); - } + return forward_fp16s(bottom_blobs, top_blobs, opt); #endif #if NCNN_BF16 @@ -454,14 +473,18 @@ int RNN_arm::forward(const std::vector& bottom_blobs, std::vector& top return -100; Mat hidden0 = hidden.row_range(0, 1); - int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); - if (ret0 != 0) - return ret0; + { + int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), 
weight_hc_data_packed.channel(0), hidden0, opt); + if (ret != 0) + return ret; + } Mat hidden1 = hidden.row_range(1, 1); - int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); - if (ret1 != 0) - return ret1; + { + int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -783,15 +806,19 @@ int RNN_arm::forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = rnn_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; + { + int ret = rnn_bf16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.f); - int ret1 = rnn_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; + { + int ret = rnn_bf16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -854,14 +881,18 @@ int RNN_arm::forward_bf16s(const std::vector& bottom_blobs, std::vector& bottom_blobs, std::vector(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(float16_to_float32(x[i]))); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + bottom_blob_int8_descales[t] = absmax / 127.f; + } + } + if (elemtype == 4) + { + // bf16 + for (int t = 0; t < T; t++) + { + const unsigned short* x = bottom_blob.row(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(bfloat16_to_float32(x[i]))); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + bottom_blob_int8_descales[t] = absmax / 127.f; + } + } + + quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt); +} + +int RNN_arm::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int elemtype = 1; // fp32 + { + int elembits = bottom_blob.elembits(); + + // clang-format off + // *INDENT-OFF* + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + { + elemtype = 2; // fp16 + } + else +#endif +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + { + elemtype = 4; // bf16 + } + else +#endif + { + // fp32 + } + + // *INDENT-ON* + // clang-format on + } + + int T = bottom_blob.h; + size_t elemsize = bottom_blob.elemsize; + + int num_directions = direction == 2 ? 
2 : 1; + + // initial hidden state + Mat hidden(num_output, 4u, opt.workspace_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + + top_blob.create(num_output * num_directions, T, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // dynamic quantize bottom_blob + Mat bottom_blob_int8; + Mat bottom_blob_int8_descales; + { + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + dynamic_quantize(bottom_blob, elemtype, bottom_blob_int8, bottom_blob_int8_descales, opt_quant); + } + + // Uni directional + if (direction == 0 || direction == 1) + { + rnn_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, direction, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), hidden, opt); + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + { + rnn_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_forward, elemtype, 0, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), hidden, opt); + } + + hidden.fill(0.0f); + + { + rnn_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_reverse, elemtype, 1, weight_data_tm.channel(1), weight_data_tm_int8_descales.channel(1), bias_c_data_packed.channel(1), hidden, opt); + } + + // concat w + for (int i = 0; i < T; i++) + { + const unsigned char* pf = top_blob_forward.row(i); + const unsigned char* pr = top_blob_reverse.row(i); + unsigned char* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * elemsize); + memcpy(ptr + num_output * elemsize, pr, num_output * elemsize); + } + } + + return 0; +} + +int RNN_arm::forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + + int elemtype = 1; // fp32 + { + int elembits = bottom_blob.elembits(); + + // clang-format off + // *INDENT-OFF* + +#if NCNN_ARM82 + if (support_fp16_storage && opt.use_fp16_storage && elembits == 16) + { + elemtype = 2; // fp16 + } + else +#endif +#if NCNN_BF16 + if (opt.use_bf16_storage && elembits == 16) + { + elemtype = 4; // bf16 + } + else +#endif + { + // fp32 + } + + // *INDENT-ON* + // clang-format on + } + + int T = bottom_blob.h; + size_t elemsize = bottom_blob.elemsize; + int num_directions = direction == 2 ? 2 : 1; + + Mat hidden; + Allocator* hidden_allocator = top_blobs.size() == 2 ? 
opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 2) + { + if (elemtype == 1) + { + hidden = bottom_blobs[1].clone(hidden_allocator); + } + if (elemtype == 2) + { + Option opt_cast = opt; + opt_cast.blob_allocator = hidden_allocator; + cast_float16_to_float32(bottom_blobs[1], hidden, opt_cast); + } + if (elemtype == 4) + { + Option opt_cast = opt; + opt_cast.blob_allocator = hidden_allocator; + cast_bfloat16_to_float32(bottom_blobs[1], hidden, opt_cast); + } + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, elemsize, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // dynamic quantize bottom_blob + Mat bottom_blob_int8; + Mat bottom_blob_int8_descales; + { + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + dynamic_quantize(bottom_blob, elemtype, bottom_blob_int8, bottom_blob_int8_descales, opt_quant); + } + + // Uni directional + if (direction == 0 || direction == 1) + { + rnn_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, direction, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), hidden, opt); + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, elemsize, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + { + rnn_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_forward, elemtype, 0, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), hidden0, opt); + } + + Mat hidden1 = hidden.row_range(1, 1); + { + rnn_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_reverse, elemtype, 1, weight_data_tm.channel(1), weight_data_tm_int8_descales.channel(1), bias_c_data_packed.channel(1), hidden1, opt); + } + + // concat w + for (int i = 0; i < T; i++) + { + const unsigned char* pf = top_blob_forward.row(i); + const unsigned char* pr = top_blob_reverse.row(i); + unsigned char* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * elemsize); + memcpy(ptr + num_output * elemsize, pr, num_output * elemsize); + } + } + + if (top_blobs.size() == 2) + { + if (elemtype == 1) + { + top_blobs[1] = hidden; + } + if (elemtype == 2) + { + cast_float32_to_float16(hidden, top_blobs[1], opt); + } + if (elemtype == 4) + { + cast_float32_to_bfloat16(hidden, top_blobs[1], opt); + } + } + + return 0; +} +#endif // NCNN_INT8 + } // namespace ncnn diff --git a/src/layer/arm/rnn_arm.h b/src/layer/arm/rnn_arm.h index 18e75642b9e..ef07a4f6b52 100644 --- a/src/layer/arm/rnn_arm.h +++ b/src/layer/arm/rnn_arm.h @@ -33,19 +33,29 @@ class RNN_arm : public RNN int create_pipeline_fp16s(const Option& opt); int forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; int forward_fp16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; - int forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; - int forward_fp16sa(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif #if NCNN_BF16 int create_pipeline_bf16s(const Option& opt); int forward_bf16s(const Mat& bottom_blob, Mat& top_blob, const Option& opt) 
const; int forward_bf16s(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; #endif +#if NCNN_INT8 + int create_pipeline_int8(const Option& opt); + void dynamic_quantize(const Mat& bottom_blob, int elemtype, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const; + int forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif public: Mat weight_xc_data_packed; Mat bias_c_data_packed; Mat weight_hc_data_packed; + + Mat weight_data_tm; + +#if NCNN_INT8 + Mat weight_data_tm_int8_descales; +#endif }; } // namespace ncnn diff --git a/src/layer/arm/rnn_arm_asimddp.cpp b/src/layer/arm/rnn_arm_asimddp.cpp new file mode 100644 index 00000000000..6e4890de91b --- /dev/null +++ b/src/layer/arm/rnn_arm_asimddp.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "cpu.h" +#include "mat.h" +#include "layer.h" +#include "arm_activation.h" +#include "arm_usability.h" + +namespace ncnn { + +#include "rnn_int8.h" + +void rnn_transform_weight_int8_asimddp(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, const Option& opt) +{ + rnn_transform_weight_int8(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, opt); +} + +void rnn_int8_asimddp(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, Mat& hidden_state, const Option& opt) +{ + rnn_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, hidden_state, opt); +} + +} // namespace ncnn diff --git a/src/layer/arm/rnn_arm_asimdhp.cpp b/src/layer/arm/rnn_arm_asimdhp.cpp index 93b009151c5..c8ef6898ce6 100644 --- a/src/layer/arm/rnn_arm_asimdhp.cpp +++ b/src/layer/arm/rnn_arm_asimdhp.cpp @@ -23,148 +23,6 @@ namespace ncnn { #if __ARM_FEATURE_FP16_VECTOR_ARITHMETIC -static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) -{ - int size = bottom_blob.w; - int T = bottom_blob.h; - - int num_output = top_blob.w; - - // num_output - Mat gates(num_output, 4u, opt.workspace_allocator); - if (gates.empty()) - return -100; - - // unroll - for (int t = 0; t < T; t++) - { - int ti = reverse ? 
T - 1 - t : t; - - const __fp16* x = bottom_blob.row(ti); - - int nn_num_output = num_output >> 2; - int remain_num_output_start = nn_num_output << 2; - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) - { - int q = qq * 4; - - const __fp16* weight_xc_ptr = weight_xc.row(q / 4); - const __fp16* weight_hc_ptr = weight_hc.row(q / 4); - - float32x4_t _rnn_H = vcvt_f32_f16(vld1_f16((const __fp16*)bias_c + q)); - float32x4_t _sum1 = vdupq_n_f32(0.f); - float32x4_t _sum2 = vdupq_n_f32(0.f); - float32x4_t _sum3 = vdupq_n_f32(0.f); - - int i = 0; - for (; i + 3 < size; i += 4) - { - float32x4_t _x = vcvt_f32_f16(vld1_f16(x + i)); - float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr)); - float32x4_t _weight_xc_1 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 4)); - float32x4_t _weight_xc_2 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 8)); - float32x4_t _weight_xc_3 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 12)); - _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc, _x, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3); - - weight_xc_ptr += 16; - } - for (; i < size; i++) - { - float32x4_t _x = vcvt_f32_f16(vdup_n_f16(x[i])); - float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr)); - _rnn_H = vfmaq_f32(_rnn_H, _weight_xc, _x); - - weight_xc_ptr += 4; - } - - i = 0; - for (; i + 3 < num_output; i += 4) - { - float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i); - float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr)); - float32x4_t _weight_hc_1 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 4)); - float32x4_t _weight_hc_2 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 8)); - float32x4_t _weight_hc_3 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 12)); - _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc, _hidden_state, 0); - _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1); - _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2); - _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3); - - weight_hc_ptr += 16; - } - for (; i < num_output; i++) - { - float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]); - float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr)); - _rnn_H = vfmaq_f32(_rnn_H, _weight_hc, _hidden_state); - - weight_hc_ptr += 4; - } - - _rnn_H = vaddq_f32(_rnn_H, _sum1); - _sum2 = vaddq_f32(_sum2, _sum3); - _rnn_H = vaddq_f32(_rnn_H, _sum2); - - _rnn_H = tanh_ps(_rnn_H); - - vst1q_f32((float*)gates + q, _rnn_H); - } - #pragma omp parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { - const __fp16* weight_xc_ptr = weight_xc.row(q / 4 + q % 4); - const __fp16* weight_hc_ptr = weight_hc.row(q / 4 + q % 4); - - float H = (float)(((const __fp16*)bias_c)[q]); - - for (int i = 0; i < size; i++) - { - H += (float)weight_xc_ptr[i] * (float)x[i]; - } - - for (int i = 0; i < num_output; i++) - { - H += (float)weight_hc_ptr[i] * hidden_state[i]; - } - - H = tanhf(H); - - gates[q] = H; - } - - __fp16* output_data = top_blob.row<__fp16>(ti); - - float* hidden_ptr = hidden_state; - - nn_num_output = num_output >> 2; - remain_num_output_start = nn_num_output << 2; - #pragma omp parallel for num_threads(opt.num_threads) - for (int qq = 0; qq < nn_num_output; qq++) - { - int q = qq * 4; - - float32x4_t _rnn_H = vld1q_f32((float*)gates + q); - - vst1q_f32(hidden_ptr + q, _rnn_H); - vst1_f16(output_data + q, vcvt_f16_f32(_rnn_H)); - } - #pragma omp 
parallel for num_threads(opt.num_threads) - for (int q = remain_num_output_start; q < num_output; q++) - { - float H = gates[q]; - - hidden_ptr[q] = H; - output_data[q] = (__fp16)H; - } - } - - return 0; -} - static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) { int size = bottom_blob.w; @@ -380,6 +238,151 @@ static int rnn_fp16sa(const Mat& bottom_blob, Mat& top_blob, int reverse, const return 0; } +static int rnn_fp16s(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc, const Mat& bias_c, const Mat& weight_hc, Mat& hidden_state, const Option& opt) +{ + if (opt.use_fp16_arithmetic) + return rnn_fp16sa(bottom_blob, top_blob, reverse, weight_xc, bias_c, weight_hc, hidden_state, opt); + + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + + // num_output + Mat gates(num_output, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + const __fp16* x = bottom_blob.row(ti); + + int nn_num_output = num_output >> 2; + int remain_num_output_start = nn_num_output << 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + const __fp16* weight_xc_ptr = weight_xc.row(q / 4); + const __fp16* weight_hc_ptr = weight_hc.row(q / 4); + + float32x4_t _rnn_H = vcvt_f32_f16(vld1_f16((const __fp16*)bias_c + q)); + float32x4_t _sum1 = vdupq_n_f32(0.f); + float32x4_t _sum2 = vdupq_n_f32(0.f); + float32x4_t _sum3 = vdupq_n_f32(0.f); + + int i = 0; + for (; i + 3 < size; i += 4) + { + float32x4_t _x = vcvt_f32_f16(vld1_f16(x + i)); + float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr)); + float32x4_t _weight_xc_1 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 4)); + float32x4_t _weight_xc_2 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 8)); + float32x4_t _weight_xc_3 = vcvt_f32_f16(vld1_f16(weight_xc_ptr + 12)); + _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_xc, _x, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_xc_1, _x, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_xc_2, _x, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_xc_3, _x, 3); + + weight_xc_ptr += 16; + } + for (; i < size; i++) + { + float32x4_t _x = vcvt_f32_f16(vdup_n_f16(x[i])); + float32x4_t _weight_xc = vcvt_f32_f16(vld1_f16(weight_xc_ptr)); + _rnn_H = vfmaq_f32(_rnn_H, _weight_xc, _x); + + weight_xc_ptr += 4; + } + + i = 0; + for (; i + 3 < num_output; i += 4) + { + float32x4_t _hidden_state = vld1q_f32((const float*)hidden_state + i); + float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr)); + float32x4_t _weight_hc_1 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 4)); + float32x4_t _weight_hc_2 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 8)); + float32x4_t _weight_hc_3 = vcvt_f32_f16(vld1_f16(weight_hc_ptr + 12)); + _rnn_H = vfmaq_laneq_f32(_rnn_H, _weight_hc, _hidden_state, 0); + _sum1 = vfmaq_laneq_f32(_sum1, _weight_hc_1, _hidden_state, 1); + _sum2 = vfmaq_laneq_f32(_sum2, _weight_hc_2, _hidden_state, 2); + _sum3 = vfmaq_laneq_f32(_sum3, _weight_hc_3, _hidden_state, 3); + + weight_hc_ptr += 16; + } + for (; i < num_output; i++) + { + float32x4_t _hidden_state = vdupq_n_f32(hidden_state[i]); + float32x4_t _weight_hc = vcvt_f32_f16(vld1_f16(weight_hc_ptr)); + _rnn_H = vfmaq_f32(_rnn_H, _weight_hc, _hidden_state); + + weight_hc_ptr += 4; + } + + _rnn_H = vaddq_f32(_rnn_H, _sum1); + _sum2 = vaddq_f32(_sum2, _sum3); + 
_rnn_H = vaddq_f32(_rnn_H, _sum2); + + _rnn_H = tanh_ps(_rnn_H); + + vst1q_f32((float*)gates + q, _rnn_H); + } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const __fp16* weight_xc_ptr = weight_xc.row(q / 4 + q % 4); + const __fp16* weight_hc_ptr = weight_hc.row(q / 4 + q % 4); + + float H = (float)(((const __fp16*)bias_c)[q]); + + for (int i = 0; i < size; i++) + { + H += (float)weight_xc_ptr[i] * (float)x[i]; + } + + for (int i = 0; i < num_output; i++) + { + H += (float)weight_hc_ptr[i] * hidden_state[i]; + } + + H = tanhf(H); + + gates[q] = H; + } + + __fp16* output_data = top_blob.row<__fp16>(ti); + + float* hidden_ptr = hidden_state; + + nn_num_output = num_output >> 2; + remain_num_output_start = nn_num_output << 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + float32x4_t _rnn_H = vld1q_f32((float*)gates + q); + + vst1q_f32(hidden_ptr + q, _rnn_H); + vst1_f16(output_data + q, vcvt_f16_f32(_rnn_H)); + } + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + float H = gates[q]; + + hidden_ptr[q] = H; + output_data[q] = (__fp16)H; + } + } + + return 0; +} + int RNN_arm::create_pipeline_fp16s(const Option& opt) { int num_directions = direction == 2 ? 2 : 1; @@ -561,15 +564,19 @@ int RNN_arm::forward_fp16s(const Mat& bottom_blob, Mat& top_blob, const Option& if (top_blob_reverse.empty()) return -100; - int ret0 = rnn_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; + { + int ret = rnn_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.f); - int ret1 = rnn_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; + { + int ret = rnn_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -632,148 +639,18 @@ int RNN_arm::forward_fp16s(const std::vector& bottom_blobs, std::vector(i); - const __fp16* pr = top_blob_reverse.row(i); - __fp16* ptr = top_blob.row<__fp16>(i); - - memcpy(ptr, pf, num_output * sizeof(__fp16)); - memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + int ret = rnn_fp16s(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); + if (ret != 0) + return ret; } - } - if (top_blobs.size() == 2) - { - cast_float32_to_float16(hidden, top_blobs[1], opt); - } - - return 0; -} - -int RNN_arm::forward_fp16sa(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const -{ - int T = bottom_blob.h; - - int num_directions = direction == 2 ? 
2 : 1; - - // initial hidden state - Mat hidden(num_output, 4u, opt.workspace_allocator); - if (hidden.empty()) - return -100; - hidden.fill(0.f); - - top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator); - if (top_blob.empty()) - return -100; - - // Uni directional - if (direction == 0 || direction == 1) - { - int ret = rnn_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } - - if (direction == 2) - { - Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_forward.empty()) - return -100; - - Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_reverse.empty()) - return -100; - - int ret0 = rnn_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; - - hidden.fill(0.f); - - int ret1 = rnn_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; - - // concat w - for (int i = 0; i < T; i++) + Mat hidden1 = hidden.row_range(1, 1); { - const __fp16* pf = top_blob_forward.row(i); - const __fp16* pr = top_blob_reverse.row(i); - __fp16* ptr = top_blob.row<__fp16>(i); - - memcpy(ptr, pf, num_output * sizeof(__fp16)); - memcpy(ptr + num_output, pr, num_output * sizeof(__fp16)); + int ret = rnn_fp16s(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); + if (ret != 0) + return ret; } - } - - return 0; -} - -int RNN_arm::forward_fp16sa(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const -{ - const Mat& bottom_blob = bottom_blobs[0]; - int T = bottom_blob.h; - int num_directions = direction == 2 ? 2 : 1; - - Mat hidden; - Allocator* hidden_allocator = top_blobs.size() == 2 ? 
opt.blob_allocator : opt.workspace_allocator; - if (bottom_blobs.size() == 2) - { - Option opt_cast = opt; - opt_cast.blob_allocator = hidden_allocator; - cast_float16_to_float32(bottom_blobs[1], hidden, opt_cast); - } - else - { - hidden.create(num_output, num_directions, 4u, hidden_allocator); - if (hidden.empty()) - return -100; - hidden.fill(0.f); - } - - Mat& top_blob = top_blobs[0]; - top_blob.create(num_output * num_directions, T, 2u, opt.blob_allocator); - if (top_blob.empty()) - return -100; - - // Uni directional - if (direction == 0 || direction == 1) - { - int ret = rnn_fp16sa(bottom_blob, top_blob, direction, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden, opt); - if (ret != 0) - return ret; - } - - if (direction == 2) - { - Mat top_blob_forward(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_forward.empty()) - return -100; - - Mat top_blob_reverse(num_output, T, 2u, opt.workspace_allocator); - if (top_blob_reverse.empty()) - return -100; - - Mat hidden0 = hidden.row_range(0, 1); - int ret0 = rnn_fp16sa(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), hidden0, opt); - if (ret0 != 0) - return ret0; - - Mat hidden1 = hidden.row_range(1, 1); - int ret1 = rnn_fp16sa(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), hidden1, opt); - if (ret1 != 0) - return ret1; // concat w for (int i = 0; i < T; i++) diff --git a/src/layer/arm/rnn_arm_vfpv4.cpp b/src/layer/arm/rnn_arm_vfpv4.cpp new file mode 100644 index 00000000000..893f6e061b1 --- /dev/null +++ b/src/layer/arm/rnn_arm_vfpv4.cpp @@ -0,0 +1,30 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#include "cpu.h" +#include "mat.h" +#include "layer.h" +#include "arm_activation.h" +#include "arm_usability.h" + +namespace ncnn { + +#include "rnn_int8.h" + +void rnn_int8_gate_output_vfpv4(const Mat& gates, Mat& hidden_state, Mat& top_blob, int ti, int elemtype, const Option& opt) +{ + rnn_int8_gate_output(gates, hidden_state, top_blob, ti, elemtype, opt); +} + +} // namespace ncnn diff --git a/src/layer/arm/rnn_int8.h b/src/layer/arm/rnn_int8.h new file mode 100644 index 00000000000..2850fe9cc98 --- /dev/null +++ b/src/layer/arm/rnn_int8.h @@ -0,0 +1,769 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. 
You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD +void rnn_transform_weight_int8_asimddp(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, const Option& opt); +void rnn_int8_asimddp(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, Mat& hidden_state, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2) +void rnn_int8_gate_output_vfpv4(const Mat& gates, Mat& hidden_state, Mat& top_blob, int ti, int elemtype, const Option& opt); +#endif + +static void rnn_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD + if (ncnn::cpu_support_arm_asimddp()) + { + rnn_transform_weight_int8_asimddp(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, opt); + return; + } +#endif + +#if __ARM_NEON + weight_data_tm.create(size * 4 + num_output * 4, num_output / 4 + num_output % 4, num_directions, 1u, 1); + weight_data_tm_int8_descales.create(4 + 4, num_output / 4 + num_output % 4, num_directions); +#else + weight_data_tm.create(size + num_output, num_output, num_directions, 1u, 1); + weight_data_tm_int8_descales.create(1 + 1, num_output, num_directions); +#endif + bias_c_tm = bias_c; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int dr = 0; dr < num_directions; dr++) + { + const Mat weight_xc_dr = weight_xc.channel(dr); + const Mat weight_hc_dr = weight_hc.channel(dr); + const float* weight_xc_int8_scales_ptr = weight_xc_int8_scales.row(dr); + const float* weight_hc_int8_scales_ptr = weight_hc_int8_scales.row(dr); + + Mat weight_data_tm_dr = weight_data_tm.channel(dr); + Mat weight_data_tm_int8_descales_dr = weight_data_tm_int8_descales.channel(dr); + + int q = 0; +#if __ARM_NEON + for (; q + 3 < num_output; q += 4) + { + const signed char* weight_xc_0 = weight_xc_dr.row(q); + const signed char* weight_xc_1 = weight_xc_dr.row(q + 1); + const signed char* weight_xc_2 = weight_xc_dr.row(q + 2); + const signed char* weight_xc_3 = weight_xc_dr.row(q + 3); + + const signed char* weight_hc_0 = weight_hc_dr.row(q); + const signed char* weight_hc_1 = weight_hc_dr.row(q + 1); + const signed char* weight_hc_2 = weight_hc_dr.row(q + 2); + const signed char* weight_hc_3 = weight_hc_dr.row(q + 3); + + signed char* kptr = weight_data_tm_dr.row(q / 4); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q / 4); + + int i = 0; +#if 
__ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_0[i]; + kptr[1] = weight_xc_0[i + 1]; + kptr[2] = weight_xc_0[i + 2]; + kptr[3] = weight_xc_0[i + 3]; + kptr[4] = weight_xc_1[i]; + kptr[5] = weight_xc_1[i + 1]; + kptr[6] = weight_xc_1[i + 2]; + kptr[7] = weight_xc_1[i + 3]; + kptr[8 + 0] = weight_xc_2[i]; + kptr[8 + 1] = weight_xc_2[i + 1]; + kptr[8 + 2] = weight_xc_2[i + 2]; + kptr[8 + 3] = weight_xc_2[i + 3]; + kptr[8 + 4] = weight_xc_3[i]; + kptr[8 + 5] = weight_xc_3[i + 1]; + kptr[8 + 6] = weight_xc_3[i + 2]; + kptr[8 + 7] = weight_xc_3[i + 3]; + + kptr += 16; + } +#else + for (; i + 7 < size; i += 8) + { + vst1_s8(kptr, vld1_s8(weight_xc_0 + i)); + vst1_s8(kptr + 8, vld1_s8(weight_xc_1 + i)); + vst1_s8(kptr + 16, vld1_s8(weight_xc_2 + i)); + vst1_s8(kptr + 24, vld1_s8(weight_xc_3 + i)); + kptr += 32; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < size; i += 2) + { + kptr[0] = weight_xc_0[i]; + kptr[1] = weight_xc_0[i + 1]; + kptr[2] = weight_xc_1[i]; + kptr[3] = weight_xc_1[i + 1]; + kptr[4] = weight_xc_2[i]; + kptr[5] = weight_xc_2[i + 1]; + kptr[6] = weight_xc_3[i]; + kptr[7] = weight_xc_3[i + 1]; + + kptr += 8; + } + for (; i < size; i++) + { + kptr[0] = weight_xc_0[i]; + kptr[1] = weight_xc_1[i]; + kptr[2] = weight_xc_2[i]; + kptr[3] = weight_xc_3[i]; + + kptr += 4; + } + + i = 0; +#if __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_0[i]; + kptr[1] = weight_hc_0[i + 1]; + kptr[2] = weight_hc_0[i + 2]; + kptr[3] = weight_hc_0[i + 3]; + kptr[4] = weight_hc_1[i]; + kptr[5] = weight_hc_1[i + 1]; + kptr[6] = weight_hc_1[i + 2]; + kptr[7] = weight_hc_1[i + 3]; + kptr[8 + 0] = weight_hc_2[i]; + kptr[8 + 1] = weight_hc_2[i + 1]; + kptr[8 + 2] = weight_hc_2[i + 2]; + kptr[8 + 3] = weight_hc_2[i + 3]; + kptr[8 + 4] = weight_hc_3[i]; + kptr[8 + 5] = weight_hc_3[i + 1]; + kptr[8 + 6] = weight_hc_3[i + 2]; + kptr[8 + 7] = weight_hc_3[i + 3]; + + kptr += 16; + } +#else + for (; i + 7 < num_output; i += 8) + { + vst1_s8(kptr, vld1_s8(weight_hc_0 + i)); + vst1_s8(kptr + 8, vld1_s8(weight_hc_1 + i)); + vst1_s8(kptr + 16, vld1_s8(weight_hc_2 + i)); + vst1_s8(kptr + 24, vld1_s8(weight_hc_3 + i)); + kptr += 32; + } +#endif // __ARM_FEATURE_DOTPROD + for (; i + 1 < num_output; i += 2) + { + kptr[0] = weight_hc_0[i]; + kptr[1] = weight_hc_0[i + 1]; + kptr[2] = weight_hc_1[i]; + kptr[3] = weight_hc_1[i + 1]; + kptr[4] = weight_hc_2[i]; + kptr[5] = weight_hc_2[i + 1]; + kptr[6] = weight_hc_3[i]; + kptr[7] = weight_hc_3[i + 1]; + + kptr += 8; + } + for (; i < num_output; i++) + { + kptr[0] = weight_hc_0[i]; + kptr[1] = weight_hc_1[i]; + kptr[2] = weight_hc_2[i]; + kptr[3] = weight_hc_3[i]; + + kptr += 4; + } + + float32x4_t _xc = vld1q_f32(weight_xc_int8_scales_ptr + q); + float32x4_t _hc = vld1q_f32(weight_hc_int8_scales_ptr + q); + +#if __aarch64__ + float32x4_t _one = vdupq_n_f32(1.f); + float32x4_t _reciprocal_xc = vdivq_f32(_one, _xc); + float32x4_t _reciprocal_hc = vdivq_f32(_one, _hc); +#else + float32x4_t _reciprocal_xc = vrecpeq_f32(_xc); + _reciprocal_xc = vmulq_f32(vrecpsq_f32(_xc, _reciprocal_xc), _reciprocal_xc); + _reciprocal_xc = vmulq_f32(vrecpsq_f32(_xc, _reciprocal_xc), _reciprocal_xc); + float32x4_t _reciprocal_hc = vrecpeq_f32(_hc); + _reciprocal_hc = vmulq_f32(vrecpsq_f32(_hc, _reciprocal_hc), _reciprocal_hc); + _reciprocal_hc = vmulq_f32(vrecpsq_f32(_hc, _reciprocal_hc), _reciprocal_hc); +#endif + + vst1q_f32(descales_ptr, _reciprocal_xc); + vst1q_f32(descales_ptr + 4, _reciprocal_hc); + } +#endif // __ARM_NEON 
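+        // Each descale stored here is the reciprocal of a per-channel weight
+        // quantization scale. The int8 kernels use it to rescale their int32
+        // accumulators back to fp32, in scalar form roughly
+        //   H = bias + acc_x * (descale_x * descale_xc) + acc_h * (descale_h * descale_hc)
+        // before the tanh activation, where descale_x and descale_h come from
+        // the dynamically quantized input and hidden state (absmax / 127) and
+        // descale_xc / descale_hc equal 1 / weight_scale. The non-aarch64 path
+        // above approximates the reciprocal with vrecpeq_f32 plus two
+        // vrecpsq_f32 Newton-Raphson steps because vdivq_f32 is unavailable there.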
+ for (; q < num_output; q++) + { + const signed char* weight_xc_0 = weight_xc_dr.row(q); + const signed char* weight_hc_0 = weight_hc_dr.row(q); + +#if __ARM_NEON + signed char* kptr = weight_data_tm_dr.row(q / 4 + q % 4); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q / 4 + q % 4); +#else + signed char* kptr = weight_data_tm_dr.row(q); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q); +#endif // __ARM_NEON + + for (int i = 0; i < size; i++) + { + kptr[0] = weight_xc_0[i]; + kptr += 1; + } + + for (int i = 0; i < num_output; i++) + { + kptr[0] = weight_hc_0[i]; + kptr += 1; + } + + descales_ptr[0] = 1.f / weight_xc_int8_scales_ptr[q]; + descales_ptr[1] = 1.f / weight_hc_int8_scales_ptr[q]; + } + } +} + +static void rnn_int8_gate_output(const Mat& gates, Mat& hidden_state, Mat& top_blob, int ti, int elemtype, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_VFPV4 && __ARM_NEON && !(__ARM_FP & 2) + if (ncnn::cpu_support_arm_vfpv4()) + { + rnn_int8_gate_output_vfpv4(gates, hidden_state, top_blob, ti, elemtype, opt); + return; + } +#endif + + const int num_output = top_blob.w; + + float* output_data = top_blob.row(ti); + + float* hidden_ptr = hidden_state; + + int remain_num_output_start = 0; +#if __ARM_NEON + int nn_num_output = num_output >> 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + float32x4_t _rnn_H = vld1q_f32((const float*)gates + q); + + vst1q_f32(hidden_ptr + q, _rnn_H); + + if (elemtype == 1) + { + // fp32 + vst1q_f32(output_data + q, _rnn_H); + } + if (elemtype == 2) + { + // fp16 + unsigned short* outptr = (unsigned short*)output_data + q; +#if (__ARM_FP & 2) +#if NCNN_GNU_INLINE_ASM +#if __aarch64__ + asm volatile( + "fcvtn v0.4h, %2.4s \n" + "st1 {v0.4h}, [%0] \n" + : "=r"(outptr) // %0 + : "0"(outptr), + "w"(_rnn_H) + : "memory", "v0"); +#else // __aarch64__ + asm volatile( + "vcvt.f16.f32 d0, %q2 \n" + "vst1.u16 {d0}, [%0] \n" + : "=r"(outptr) // %0 + : "0"(outptr), + "w"(_rnn_H) + : "memory", "q0"); +#endif // __aarch64__ +#else // NCNN_GNU_INLINE_ASM + vst1_u16(outptr, (uint16x4_t)vcvt_f16_f32(_rnn_H)); +#endif // NCNN_GNU_INLINE_ASM +#else + outptr[q] = float32_to_float16(hidden_ptr[q]); + outptr[q + 1] = float32_to_float16(hidden_ptr[q + 1]); + outptr[q + 2] = float32_to_float16(hidden_ptr[q + 2]); + outptr[q + 3] = float32_to_float16(hidden_ptr[q + 3]); +#endif // (__ARM_FP & 2) + } + if (elemtype == 4) + { + // bf16 + vst1_u16((unsigned short*)output_data + q, float2bfloat(_rnn_H)); + } + } + remain_num_output_start += nn_num_output << 2; +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + float H = gates[q]; + + hidden_ptr[q] = H; + + if (elemtype == 1) + { + output_data[q] = H; + } + if (elemtype == 2) + { + ((unsigned short*)output_data)[q] = float32_to_float16(H); + } + if (elemtype == 4) + { + ((unsigned short*)output_data)[q] = float32_to_bfloat16(H); + } + } +} + +static void rnn_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int elemtype, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, Mat& hidden_state, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_ARM82DOT && __aarch64__ && !__ARM_FEATURE_DOTPROD + if (ncnn::cpu_support_arm_asimddp()) + { + rnn_int8_asimddp(bottom_blob_int8, bottom_blob_int8_descales, top_blob, elemtype, reverse, weight_data_tm, 
weight_data_tm_int8_descales, bias_c, hidden_state, opt); + return; + } +#endif + + int size = bottom_blob_int8.w; + int T = bottom_blob_int8.h; + + int num_output = top_blob.w; + + // num_output + Mat gates(num_output, 4u, opt.workspace_allocator); + + Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator); + float hidden_state_int8_scale = 1.f; + float hidden_state_int8_descale = 1.f; + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + // dynamic quantize hidden_state + { + float absmax = 0.f; + for (int i = 0; i < num_output; i++) + { + absmax = std::max(absmax, (float)fabs(hidden_state[i])); + } + + if (absmax == 0.f) + { + hidden_state_int8.fill(0); + } + else + { + hidden_state_int8_scale = 127.f / absmax; + hidden_state_int8_descale = absmax / 127.f; + + signed char* hs = hidden_state_int8; + for (int i = 0; i < num_output; i++) + { + hs[i] = float2int8(hidden_state[i] * hidden_state_int8_scale); + } + } + } + + int remain_num_output_start = 0; +#if __ARM_NEON + int nn_num_output = num_output >> 2; + remain_num_output_start = nn_num_output << 2; + + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_num_output; qq++) + { + int q = qq * 4; + + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = bottom_blob_int8_descales[ti]; + const float descale_h = hidden_state_int8_descale; + + const signed char* kptr = weight_data_tm.row(q / 4); + + const float* descales_ptr = weight_data_tm_int8_descales.row(q / 4); + + int32x4_t _rnn_Hx0 = vdupq_n_s32(0); + int i = 0; +#if __ARM_FEATURE_DOTPROD + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + for (; i + 15 < size; i += 16) + { + int8x16_t _xi = vld1q_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + _rnn_Hx0 = vdotq_laneq_s32(_rnn_Hx0, _w0, _xi, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w1, _xi, 1); + _sum2 = vdotq_laneq_s32(_sum2, _w2, _xi, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w3, _xi, 3); + + kptr += 64; + } + for (; i + 7 < size; i += 8) + { + int8x8_t _xi = vld1_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _rnn_Hx0 = vdotq_lane_s32(_rnn_Hx0, _w0, _xi, 0); + _sum1 = vdotq_lane_s32(_sum1, _w1, _xi, 1); + + kptr += 32; + } + _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum1); + _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum2); + _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum3); +#else + int32x4_t _sum0 = vdupq_n_s32(0); + int32x4_t _sum1 = vdupq_n_s32(0); + int32x4_t _sum2 = vdupq_n_s32(0); + int32x4_t _sum3 = vdupq_n_s32(0); + for (; i + 15 < size; i += 16) + { +#if NCNN_GNU_INLINE_ASM && !__aarch64__ + const signed char* xptr = x + i; + + asm volatile( + "vldm %1!, {d0-d7} \n" + "vld1.s8 {d16-d17}, [%0] \n" + "vmull.s8 q4, d0, d16 \n" + "vmull.s8 q5, d1, d16 \n" + "vmull.s8 q6, d2, d16 \n" + "vmull.s8 q7, d3, d16 \n" + "vmlal.s8 q4, d4, d17 \n" + "vmlal.s8 q5, d5, d17 \n" + "vmlal.s8 q6, d6, d17 \n" + "vmlal.s8 q7, d7, d17 \n" + "vpadal.s16 %q2, q4 \n" + "vpadal.s16 %q3, q5 \n" + "vpadal.s16 %q4, q6 \n" + "vpadal.s16 %q5, q7 \n" + : "=r"(xptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3) + : "0"(xptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"); +#else + int8x16_t _xi = vld1q_s8(x + i); + int8x16_t _w0 = 
vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_xi)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_xi)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), vget_low_s8(_xi)); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_xi)); + _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_xi)); + _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_xi)); + _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_xi)); + _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_xi)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 64; +#endif + } + for (; i + 7 < size; i += 8) + { + int8x8_t _xi = vld1_s8(x + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _xi); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _xi); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _xi); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _xi); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 32; + } + { + int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1])); + } + _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum0); + _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum1); + _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum2); + _rnn_Hx0 = vaddq_s32(_rnn_Hx0, _sum3); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < size; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _xi = vld1_s8(x + i); + int8x16_t _w = vld1q_s8(kptr); + _rnn_Hx0 = vdotq_lane_s32(_rnn_Hx0, _w, _xi, 0); +#else + int16x4_t _xi01 = vreinterpret_s16_s8(vld1_s8(x + i)); + int8x8_t _xi0 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 0)); + int8x8_t _xi1 = vreinterpret_s8_s16(vdup_lane_s16(_xi01, 1)); + int8x16_t _w01 = vld1q_s8(kptr); + + int16x8_t _rnn_Hx = vmull_s8(vget_low_s8(_w01), _xi0); + _rnn_Hx = vmlal_s8(_rnn_Hx, vget_high_s8(_w01), _xi1); + _rnn_Hx0 = vpadalq_s16(_rnn_Hx0, _rnn_Hx); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 16; + } + for (; i + 1 < size; i += 2) + { + int8x8_t _xi = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(x + i)), 0)); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _rnn_Hx = vmull_s8(_w, _xi); + _rnn_Hx0 = vpadalq_s16(_rnn_Hx0, _rnn_Hx); + + kptr += 8; + } + for (; i < size; i++) + { + int8x8_t _xi = vdup_n_s8(x[i]); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _rnn_Hx = vmull_s8(_w, _xi); + _rnn_Hx0 = vaddw_s16(_rnn_Hx0, vget_low_s16(_rnn_Hx)); + + kptr += 4; + } + + int32x4_t _rnn_Hh0 = vdupq_n_s32(0); + i = 0; +#if __ARM_FEATURE_DOTPROD + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + for (; i + 15 < num_output; i += 16) + { + int8x16_t _h_cont = vld1q_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + _rnn_Hh0 = vdotq_laneq_s32(_rnn_Hh0, _w0, _h_cont, 0); + _sum1 = vdotq_laneq_s32(_sum1, _w1, _h_cont, 1); + _sum2 = 
vdotq_laneq_s32(_sum2, _w2, _h_cont, 2); + _sum3 = vdotq_laneq_s32(_sum3, _w3, _h_cont, 3); + + kptr += 64; + } + for (; i + 7 < num_output; i += 8) + { + int8x8_t _h_cont = vld1_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + _rnn_Hh0 = vdotq_lane_s32(_rnn_Hh0, _w0, _h_cont, 0); + _sum1 = vdotq_lane_s32(_sum1, _w1, _h_cont, 1); + + kptr += 32; + } + _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum1); + _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum2); + _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum3); +#else + _sum0 = vdupq_n_s32(0); + _sum1 = vdupq_n_s32(0); + _sum2 = vdupq_n_s32(0); + _sum3 = vdupq_n_s32(0); + for (; i + 15 < num_output; i += 16) + { +#if NCNN_GNU_INLINE_ASM && !__aarch64__ + const signed char* hsptr = hs + i; + + asm volatile( + "vldm %1!, {d0-d7} \n" + "vld1.s8 {d16-d17}, [%0] \n" + "vmull.s8 q4, d0, d16 \n" + "vmull.s8 q5, d1, d16 \n" + "vmull.s8 q6, d2, d16 \n" + "vmull.s8 q7, d3, d16 \n" + "vmlal.s8 q4, d4, d17 \n" + "vmlal.s8 q5, d5, d17 \n" + "vmlal.s8 q6, d6, d17 \n" + "vmlal.s8 q7, d7, d17 \n" + "vpadal.s16 %q2, q4 \n" + "vpadal.s16 %q3, q5 \n" + "vpadal.s16 %q4, q6 \n" + "vpadal.s16 %q5, q7 \n" + : "=r"(hsptr), "=r"(kptr), "=w"(_sum0), "=w"(_sum1), "=w"(_sum2), "=w"(_sum3) + : "0"(hsptr), "1"(kptr), "2"(_sum0), "3"(_sum1), "4"(_sum2), "5"(_sum3) + : "memory", "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8"); +#else + int8x16_t _h_cont = vld1q_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + int8x16_t _w2 = vld1q_s8(kptr + 32); + int8x16_t _w3 = vld1q_s8(kptr + 48); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), vget_low_s8(_h_cont)); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), vget_low_s8(_h_cont)); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), vget_low_s8(_h_cont)); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), vget_low_s8(_h_cont)); + _s0 = vmlal_s8(_s0, vget_low_s8(_w2), vget_high_s8(_h_cont)); + _s1 = vmlal_s8(_s1, vget_high_s8(_w2), vget_high_s8(_h_cont)); + _s2 = vmlal_s8(_s2, vget_low_s8(_w3), vget_high_s8(_h_cont)); + _s3 = vmlal_s8(_s3, vget_high_s8(_w3), vget_high_s8(_h_cont)); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 64; +#endif + } + for (; i + 7 < num_output; i += 8) + { + int8x8_t _h_cont = vld1_s8(hs + i); + int8x16_t _w0 = vld1q_s8(kptr); + int8x16_t _w1 = vld1q_s8(kptr + 16); + + int16x8_t _s0 = vmull_s8(vget_low_s8(_w0), _h_cont); + int16x8_t _s1 = vmull_s8(vget_high_s8(_w0), _h_cont); + int16x8_t _s2 = vmull_s8(vget_low_s8(_w1), _h_cont); + int16x8_t _s3 = vmull_s8(vget_high_s8(_w1), _h_cont); + _sum0 = vpadalq_s16(_sum0, _s0); + _sum1 = vpadalq_s16(_sum1, _s1); + _sum2 = vpadalq_s16(_sum2, _s2); + _sum3 = vpadalq_s16(_sum3, _s3); + + kptr += 32; + } + { + int32x4x2_t _tmp0 = vzipq_s32(_sum0, _sum1); + int32x4x2_t _tmp1 = vzipq_s32(_sum2, _sum3); + _sum0 = vcombine_s32(vget_low_s32(_tmp0.val[0]), vget_low_s32(_tmp1.val[0])); + _sum1 = vcombine_s32(vget_high_s32(_tmp0.val[0]), vget_high_s32(_tmp1.val[0])); + _sum2 = vcombine_s32(vget_low_s32(_tmp0.val[1]), vget_low_s32(_tmp1.val[1])); + _sum3 = vcombine_s32(vget_high_s32(_tmp0.val[1]), vget_high_s32(_tmp1.val[1])); + } + _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum0); + _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum1); + _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum2); + _rnn_Hh0 = vaddq_s32(_rnn_Hh0, _sum3); +#endif // __ARM_FEATURE_DOTPROD + for (; i + 3 < num_output; i += 4) + { +#if __ARM_FEATURE_DOTPROD + int8x8_t _h_cont = vld1_s8(hs + i); + int8x16_t _w = 
vld1q_s8(kptr); + _rnn_Hh0 = vdotq_lane_s32(_rnn_Hh0, _w, _h_cont, 0); +#else + int16x4_t _h_cont01 = vreinterpret_s16_s8(vld1_s8(hs + i)); + int8x8_t _h_cont0 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 0)); + int8x8_t _h_cont1 = vreinterpret_s8_s16(vdup_lane_s16(_h_cont01, 1)); + int8x16_t _w01 = vld1q_s8(kptr); + + int16x8_t _rnn_Hh = vmull_s8(vget_low_s8(_w01), _h_cont0); + _rnn_Hh = vmlal_s8(_rnn_Hh, vget_high_s8(_w01), _h_cont1); + _rnn_Hh0 = vpadalq_s16(_rnn_Hh0, _rnn_Hh); +#endif // __ARM_FEATURE_DOTPROD + + kptr += 16; + } + for (; i + 1 < num_output; i += 2) + { + int8x8_t _h_cont = vreinterpret_s8_s16(vdup_lane_s16(vreinterpret_s16_s8(vld1_s8(hs + i)), 0)); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _rnn_Hh = vmull_s8(_w, _h_cont); + _rnn_Hh0 = vpadalq_s16(_rnn_Hh0, _rnn_Hh); + + kptr += 8; + } + for (; i < num_output; i++) + { + int8x8_t _h_cont = vdup_n_s8(hs[i]); + int8x8_t _w = vld1_s8(kptr); + + int16x8_t _rnn_Hh = vmull_s8(_w, _h_cont); + _rnn_Hh0 = vaddw_s16(_rnn_Hh0, vget_low_s16(_rnn_Hh)); + + kptr += 4; + } + + float32x4_t _descale_x = vdupq_n_f32(descale_x); + float32x4_t _descale_h = vdupq_n_f32(descale_h); + + float32x4_t _rnn_H = vld1q_f32((const float*)bias_c + q); + + float32x4_t _descale_xc = vld1q_f32(descales_ptr); + + _rnn_H = vmlaq_f32(_rnn_H, vcvtq_f32_s32(_rnn_Hx0), vmulq_f32(_descale_x, _descale_xc)); + + float32x4_t _descale_hc = vld1q_f32(descales_ptr + 4); + + _rnn_H = vmlaq_f32(_rnn_H, vcvtq_f32_s32(_rnn_Hh0), vmulq_f32(_descale_h, _descale_hc)); + + _rnn_H = tanh_ps(_rnn_H); + + vst1q_f32((float*)gates + q, _rnn_H); + } +#endif // __ARM_NEON + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = bottom_blob_int8_descales[ti]; + const float descale_h = hidden_state_int8_descale; + +#if __ARM_NEON + const signed char* kptr = weight_data_tm.row(q / 4 + q % 4); + const float* descales_ptr = weight_data_tm_int8_descales.row(q / 4 + q % 4); +#else + const signed char* kptr = weight_data_tm.row(q); + const float* descales_ptr = weight_data_tm_int8_descales.row(q); +#endif // __ARM_NEON + + const float descale_xc = descales_ptr[0]; + const float descale_hc = descales_ptr[1]; + + int Hx = 0; + for (int i = 0; i < size; i++) + { + Hx += kptr[0] * x[i]; + kptr += 1; + } + + int Hh = 0; + for (int i = 0; i < num_output; i++) + { + Hh += kptr[0] * hs[i]; + kptr += 1; + } + + float H = bias_c[q] + Hx * (descale_x * descale_xc) + Hh * (descale_h * descale_hc); + + H = tanhf(H); + + gates[q] = H; + } + + rnn_int8_gate_output(gates, hidden_state, top_blob, ti, elemtype, opt); + } +} diff --git a/src/layer/gru.cpp b/src/layer/gru.cpp index b1ef2e0da45..6da1f715d7a 100644 --- a/src/layer/gru.cpp +++ b/src/layer/gru.cpp @@ -27,6 +27,16 @@ int GRU::load_param(const ParamDict& pd) num_output = pd.get(0, 0); weight_data_size = pd.get(1, 0); direction = pd.get(2, 0); + int8_scale_term = pd.get(8, 0); + + if (int8_scale_term) + { +#if !NCNN_INT8 + NCNN_LOGE("please build ncnn with NCNN_INT8 enabled for int8 inference"); + return -1; +#endif + } + return 0; } @@ -49,6 +59,14 @@ int GRU::load_model(const ModelBin& mb) if (weight_hc_data.empty()) return -100; +#if NCNN_INT8 + if (int8_scale_term) + { + weight_xc_data_int8_scales = mb.load(num_output * 3, num_directions, 1); + weight_hc_data_int8_scales = mb.load(num_output * 3, num_directions, 1); + } +#endif // NCNN_INT8 + return 0; } @@ 
-160,6 +178,182 @@ static int gru(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we return 0; } +#if NCNN_INT8 +static int gru_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const float* weight_xc_int8_scales, const Mat& bias_c, const Mat& weight_hc_int8, const float* weight_hc_int8_scales, Mat& hidden_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + + // 2 x num_output + Mat gates(2, num_output, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + // dynamic quantize bottom_blob + Mat bottom_blob_int8(size, T, (size_t)1u, 1, opt.workspace_allocator); + Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.workspace_allocator); + { + for (int t = 0; t < T; t++) + { + const float* x = bottom_blob.row(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(x[i])); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + } + + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_quant); + } + + Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator); + Mat hidden_state_int8_scales(1, (size_t)4u, 1, opt.workspace_allocator); + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + // dynamic quantize hidden_state + { + float absmax = 0.f; + for (int i = 0; i < num_output; i++) + { + absmax = std::max(absmax, (float)fabs(hidden_state[i])); + } + + if (absmax == 0.f) + { + hidden_state_int8_scales[0] = 1.f; + hidden_state_int8.fill(0); + } + else + { + hidden_state_int8_scales[0] = 127.f / absmax; + + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + quantize_to_int8(hidden_state, hidden_state_int8, hidden_state_int8_scales, opt_quant); + } + } + + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = 1.f / bottom_blob_int8_scales[ti]; + const float descale_h = 1.f / hidden_state_int8_scales[0]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_output; q++) + { + float* gates_data = gates.row(q); + + // gate reset update + const float* bias_c_R = bias_c.row(0); + const float* bias_c_U = bias_c.row(1); + + const signed char* weight_xc_int8_R = weight_xc_int8.row(num_output * 0 + q); + const signed char* weight_xc_int8_U = weight_xc_int8.row(num_output * 1 + q); + const signed char* weight_hc_int8_R = weight_hc_int8.row(num_output * 0 + q); + const signed char* weight_hc_int8_U = weight_hc_int8.row(num_output * 1 + q); + + const float descale_xc_R = 1.f / weight_xc_int8_scales[num_output * 0 + q]; + const float descale_xc_U = 1.f / weight_xc_int8_scales[num_output * 1 + q]; + const float descale_hc_R = 1.f / weight_hc_int8_scales[num_output * 0 + q]; + const float descale_hc_U = 1.f / weight_hc_int8_scales[num_output * 1 + q]; + + int Rx = 0; + int Ux = 0; + for (int i = 0; i < size; i++) + { + signed char xi = x[i]; + + Rx += weight_xc_int8_R[i] * xi; + Ux += weight_xc_int8_U[i] * xi; + } + + int Rh = 0; + int Uh = 0; + for (int i = 0; i < num_output; i++) + { + signed char h_cont = hs[i]; + + Rh += weight_hc_int8_R[i] * h_cont; + Uh += weight_hc_int8_U[i] * h_cont; + } + + float R = bias_c_R[q] + Rx * (descale_x * descale_xc_R) + Rh * (descale_h * descale_hc_R); + 
float U = bias_c_U[q] + Ux * (descale_x * descale_xc_U) + Uh * (descale_h * descale_hc_U); + + // sigmoid(R) + // sigmoid(U) + R = 1.f / (1.f + expf(-R)); + U = 1.f / (1.f + expf(-U)); + + // gate new + const float* bias_c_WN = bias_c.row(2); + const float* bias_c_BN = bias_c.row(3); + + const signed char* weight_xc_int8_N = weight_xc_int8.row(num_output * 2 + q); + const signed char* weight_hc_int8_N = weight_hc_int8.row(num_output * 2 + q); + + const float descale_xc_N = 1.f / weight_xc_int8_scales[num_output * 2 + q]; + const float descale_hc_N = 1.f / weight_hc_int8_scales[num_output * 2 + q]; + + int Nh = 0; + for (int i = 0; i < num_output; i++) + { + signed char h_cont = hs[i]; + + Nh += weight_hc_int8_N[i] * h_cont; + } + + int Nx = 0; + for (int i = 0; i < size; i++) + { + signed char xi = x[i]; + + Nx += weight_xc_int8_N[i] * xi; + } + + float N = bias_c_BN[q] + Nh * (descale_h * descale_hc_N); + N = bias_c_WN[q] + R * N + Nx * (descale_x * descale_xc_N); + + // tanh(N) + N = tanhf(N); + + gates_data[0] = U; + gates_data[1] = N; + } + + // h_t := (1 - update) .* new + update .* h_{t-1} + float* output_data = top_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_output; q++) + { + const float* gates_data = gates.row(q); + + float U = gates_data[0]; + float N = gates_data[1]; + + float H = (1 - U) * N + U * hidden_state[q]; + + hidden_state[q] = H; + output_data[q] = H; + } + } + + return 0; +} +#endif // NCNN_INT8 + int GRU::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int T = bottom_blob.h; @@ -179,9 +373,20 @@ int GRU::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const // Uni directional if (direction == 0 || direction == 1) { - int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = gru_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -194,15 +399,37 @@ int GRU::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const if (top_blob_reverse.empty()) return -100; - int ret0 = gru(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = gru_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.0f); - int ret1 = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = gru_int8(bottom_blob, top_blob_reverse, 1, 
weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -247,9 +474,20 @@ int GRU::forward(const std::vector& bottom_blobs, std::vector& top_blo // Uni directional if (direction == 0 || direction == 1) { - int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = gru_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = gru(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -263,14 +501,36 @@ int GRU::forward(const std::vector& bottom_blobs, std::vector& top_blo return -100; Mat hidden0 = hidden.row_range(0, 1); - int ret0 = gru(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = gru_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden0, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = gru(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt); + if (ret != 0) + return ret; + } Mat hidden1 = hidden.row_range(1, 1); - int ret1 = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = gru_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden1, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = gru(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) diff --git a/src/layer/gru.h b/src/layer/gru.h index 976550f7722..1f9d73cf7db 100644 --- a/src/layer/gru.h +++ b/src/layer/gru.h @@ -37,9 +37,16 @@ class GRU : public Layer int weight_data_size; int direction; // 0=forward 1=reverse 2=bidirectional + int int8_scale_term; + Mat weight_hc_data; Mat weight_xc_data; Mat bias_c_data; + +#if NCNN_INT8 + Mat weight_hc_data_int8_scales; + Mat weight_xc_data_int8_scales; +#endif }; } // namespace ncnn diff --git a/src/layer/lstm.cpp b/src/layer/lstm.cpp index c761a98d4dd..53f2ac25cf9 100644 --- a/src/layer/lstm.cpp +++ b/src/layer/lstm.cpp @@ -28,6 +28,16 @@ int LSTM::load_param(const ParamDict& pd) weight_data_size = pd.get(1, 0); direction = pd.get(2, 0); hidden_size = pd.get(3, 
num_output); + int8_scale_term = pd.get(8, 0); + + if (int8_scale_term) + { +#if !NCNN_INT8 + NCNN_LOGE("please build ncnn with NCNN_INT8 enabled for int8 inference"); + return -1; +#endif + } + return 0; } @@ -57,6 +67,14 @@ int LSTM::load_model(const ModelBin& mb) return -100; } +#if NCNN_INT8 + if (int8_scale_term) + { + weight_xc_data_int8_scales = mb.load(hidden_size * 4, num_directions, 1); + weight_hc_data_int8_scales = mb.load(hidden_size * 4, num_directions, 1); + } +#endif // NCNN_INT8 + return 0; } @@ -206,6 +224,224 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w return 0; } +#if NCNN_INT8 +static int lstm_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const float* weight_xc_int8_scales, const Mat& bias_c, const Mat& weight_hc_int8, const float* weight_hc_int8_scales, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + if (tmp_hidden_state.empty()) + return -100; + } + + // dynamic quantize bottom_blob + Mat bottom_blob_int8(size, T, (size_t)1u, 1, opt.workspace_allocator); + Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.workspace_allocator); + { + for (int t = 0; t < T; t++) + { + const float* x = bottom_blob.row(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(x[i])); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + } + + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_quant); + } + + Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator); + Mat hidden_state_int8_scales(1, (size_t)4u, 1, opt.workspace_allocator); + + // unroll + for (int t = 0; t < T; t++) + { + // clip hidden by continuation indicator + // h_cont_{t-1} = cont_t * h_{t-1} + // h_cont_{t-1} = h_{t-1} if cont_t == 1 + // 0 otherwise + // calculate hidden + // gate_input_t := W_hc * h_conted_{t-1} + W_xc * x_t + b_c + + int ti = reverse ? 
T - 1 - t : t; + + // dynamic quantize hidden_state + { + float absmax = 0.f; + for (int i = 0; i < num_output; i++) + { + absmax = std::max(absmax, (float)fabs(hidden_state[i])); + } + + if (absmax == 0.f) + { + hidden_state_int8_scales[0] = 1.f; + hidden_state_int8.fill(0); + } + else + { + hidden_state_int8_scales[0] = 127.f / absmax; + + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + quantize_to_int8(hidden_state, hidden_state_int8, hidden_state_int8_scales, opt_quant); + } + } + + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = 1.f / bottom_blob_int8_scales[ti]; + const float descale_h = 1.f / hidden_state_int8_scales[0]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const float* bias_c_I = bias_c.row(0); + const float* bias_c_F = bias_c.row(1); + const float* bias_c_O = bias_c.row(2); + const float* bias_c_G = bias_c.row(3); + + float* gates_data = gates.row(q); + + // gate I F O G + const signed char* weight_xc_int8_I = weight_xc_int8.row(hidden_size * 0 + q); + const signed char* weight_xc_int8_F = weight_xc_int8.row(hidden_size * 1 + q); + const signed char* weight_xc_int8_O = weight_xc_int8.row(hidden_size * 2 + q); + const signed char* weight_xc_int8_G = weight_xc_int8.row(hidden_size * 3 + q); + + const signed char* weight_hc_int8_I = weight_hc_int8.row(hidden_size * 0 + q); + const signed char* weight_hc_int8_F = weight_hc_int8.row(hidden_size * 1 + q); + const signed char* weight_hc_int8_O = weight_hc_int8.row(hidden_size * 2 + q); + const signed char* weight_hc_int8_G = weight_hc_int8.row(hidden_size * 3 + q); + + const float descale_xc_I = 1.f / weight_xc_int8_scales[hidden_size * 0 + q]; + const float descale_xc_F = 1.f / weight_xc_int8_scales[hidden_size * 1 + q]; + const float descale_xc_O = 1.f / weight_xc_int8_scales[hidden_size * 2 + q]; + const float descale_xc_G = 1.f / weight_xc_int8_scales[hidden_size * 3 + q]; + const float descale_hc_I = 1.f / weight_hc_int8_scales[hidden_size * 0 + q]; + const float descale_hc_F = 1.f / weight_hc_int8_scales[hidden_size * 1 + q]; + const float descale_hc_O = 1.f / weight_hc_int8_scales[hidden_size * 2 + q]; + const float descale_hc_G = 1.f / weight_hc_int8_scales[hidden_size * 3 + q]; + + int Ix = 0; + int Fx = 0; + int Ox = 0; + int Gx = 0; + for (int i = 0; i < size; i++) + { + signed char xi = x[i]; + + Ix += weight_xc_int8_I[i] * xi; + Fx += weight_xc_int8_F[i] * xi; + Ox += weight_xc_int8_O[i] * xi; + Gx += weight_xc_int8_G[i] * xi; + } + + int Ih = 0; + int Fh = 0; + int Oh = 0; + int Gh = 0; + for (int i = 0; i < num_output; i++) + { + signed char h_cont = hs[i]; + + Ih += weight_hc_int8_I[i] * h_cont; + Fh += weight_hc_int8_F[i] * h_cont; + Oh += weight_hc_int8_O[i] * h_cont; + Gh += weight_hc_int8_G[i] * h_cont; + } + + float I = bias_c_I[q] + Ix * (descale_x * descale_xc_I) + Ih * (descale_h * descale_hc_I); + float F = bias_c_F[q] + Fx * (descale_x * descale_xc_F) + Fh * (descale_h * descale_hc_F); + float O = bias_c_O[q] + Ox * (descale_x * descale_xc_O) + Oh * (descale_h * descale_hc_O); + float G = bias_c_G[q] + Gx * (descale_x * descale_xc_G) + Gh * (descale_h * descale_hc_G); + + gates_data[0] = I; + gates_data[1] = F; + gates_data[2] = O; + gates_data[3] = G; + } + + // lstm unit + // sigmoid(I) + // sigmoid(F) + // sigmoid(O) + // tanh(G) + // c_t := f_t .* c_{t-1} + i_t .* g_t + // h_t := o_t .* tanh[c_t] + 
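+        // descriptive note: the int8 accumulation is finished above; the gate activations and the
+        // cell/hidden state update below operate in fp32 on the dequantized gate values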
float* output_data = top_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < hidden_size; q++) + { + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_state[q] + I * G; + float H = O * tanhf(cell2); + cell_state[q] = cell2; + + if (num_output == hidden_size) + { + hidden_state[q] = H; + output_data[q] = H; + } + else + { + tmp_hidden_state[q] = H; + } + } + + if (num_output != hidden_size) + { + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_state[i] * hr[i]; + } + + hidden_state[q] = H; + output_data[q] = H; + } + } + } + + return 0; +} +#endif // NCNN_INT8 + int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int T = bottom_blob.h; @@ -230,9 +466,20 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -245,16 +492,38 @@ int LSTM::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) cons if (top_blob_reverse.empty()) return -100; - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } hidden.fill(0.0f); cell.fill(0.0f); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -306,9 +575,20 @@ int LSTM::forward(const std::vector& bottom_blobs, std::vector& top_bl // Uni directional if (direction == 0 || direction == 1) { - int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -323,15 +603,37 @@ int LSTM::forward(const std::vector& bottom_blobs, std::vector& top_bl Mat hidden0 = hidden.row_range(0, 1); Mat cell0 = cell.row_range(0, 1); - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } Mat hidden1 = hidden.row_range(1, 1); Mat cell1 = cell.row_range(1, 1); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = lstm_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) diff --git a/src/layer/lstm.h b/src/layer/lstm.h index 58bd67f987a..5e4938d7bf3 100644 --- a/src/layer/lstm.h +++ b/src/layer/lstm.h @@ -38,10 +38,17 @@ class LSTM : public Layer int direction; // 0=forward 1=reverse 2=bidirectional int hidden_size; + int int8_scale_term; + Mat weight_hc_data; Mat weight_xc_data; Mat bias_c_data; Mat weight_hr_data; + +#if NCNN_INT8 + Mat weight_hc_data_int8_scales; + Mat weight_xc_data_int8_scales; +#endif }; } // namespace ncnn diff --git a/src/layer/riscv/gru_riscv.cpp b/src/layer/riscv/gru_riscv.cpp index 0869a455979..25218ddc32e 100644 --- a/src/layer/riscv/gru_riscv.cpp +++ b/src/layer/riscv/gru_riscv.cpp @@ -215,6 +215,14 @@ GRU_riscv::GRU_riscv() int GRU_riscv::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + support_fp16_storage = false; + return 0; + } +#endif + #if __riscv_vector && __riscv_zfh if (opt.use_fp16_storage && opt.use_fp16_arithmetic) return create_pipeline_fp16sa(opt); @@ -225,6 +233,13 @@ int GRU_riscv::create_pipeline(const Option& opt) int GRU_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return GRU::forward(bottom_blob, top_blob, opt); + } +#endif + int elembits = bottom_blob.elembits(); #if __riscv_vector @@ -299,6 +314,13 @@ int GRU_riscv::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) int GRU_riscv::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return GRU::forward(bottom_blobs, top_blobs, opt); + } +#endif + const Mat& bottom_blob = bottom_blobs[0]; int elembits = bottom_blob.elembits(); diff --git a/src/layer/rnn.cpp b/src/layer/rnn.cpp index 6cc8ba5c9bd..49af3b300f6 100644 --- a/src/layer/rnn.cpp +++ b/src/layer/rnn.cpp @@ -27,6 +27,16 @@ int RNN::load_param(const ParamDict& pd) num_output = pd.get(0, 0); weight_data_size = pd.get(1, 0); direction = pd.get(2, 0); + int8_scale_term = pd.get(8, 0); + + if (int8_scale_term) + { +#if !NCNN_INT8 + NCNN_LOGE("please build ncnn with NCNN_INT8 enabled for int8 inference"); + return -1; +#endif + } + return 0; } @@ -49,6 +59,14 @@ int RNN::load_model(const ModelBin& mb) if (weight_hc_data.empty()) return -100; +#if NCNN_INT8 + if (int8_scale_term) + { + weight_xc_data_int8_scales = mb.load(num_output, num_directions, 1); + weight_hc_data_int8_scales = mb.load(num_output, num_directions, 1); + } +#endif // NCNN_INT8 + return 0; } @@ -107,6 +125,121 @@ static int rnn(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& we return 0; } +#if NCNN_INT8 +static int rnn_int8(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& weight_xc_int8, const float* weight_xc_int8_scales, const Mat& bias_c, const Mat& weight_hc_int8, const float* weight_hc_int8_scales, Mat& hidden_state, const Option& opt) +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + int num_output = top_blob.w; + + // num_output + Mat gates(num_output, 4u, opt.workspace_allocator); + if (gates.empty()) + return -100; + + // dynamic quantize bottom_blob + Mat bottom_blob_int8(size, T, 
(size_t)1u, 1, opt.workspace_allocator); + Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.workspace_allocator); + { + for (int t = 0; t < T; t++) + { + const float* x = bottom_blob.row(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(x[i])); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + } + + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt_quant); + } + + Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator); + Mat hidden_state_int8_scales(1, (size_t)4u, 1, opt.workspace_allocator); + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + // dynamic quantize hidden_state + { + float absmax = 0.f; + for (int i = 0; i < num_output; i++) + { + absmax = std::max(absmax, (float)fabs(hidden_state[i])); + } + + if (absmax == 0.f) + { + hidden_state_int8_scales[0] = 1.f; + hidden_state_int8.fill(0); + } + else + { + hidden_state_int8_scales[0] = 127.f / absmax; + + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + quantize_to_int8(hidden_state, hidden_state_int8, hidden_state_int8_scales, opt_quant); + } + } + + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = 1.f / bottom_blob_int8_scales[ti]; + const float descale_h = 1.f / hidden_state_int8_scales[0]; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_output; q++) + { + const signed char* weight_xc_int8_ptr = weight_xc_int8.row(q); + const signed char* weight_hc_int8_ptr = weight_hc_int8.row(q); + + const float descale_xc = 1.f / weight_xc_int8_scales[q]; + const float descale_hc = 1.f / weight_hc_int8_scales[q]; + + int Hx = 0; + for (int i = 0; i < size; i++) + { + Hx += weight_xc_int8_ptr[i] * x[i]; + } + + int Hh = 0; + for (int i = 0; i < num_output; i++) + { + Hh += weight_hc_int8_ptr[i] * hs[i]; + } + + float H = bias_c[q] + Hx * (descale_x * descale_xc) + Hh * (descale_h * descale_hc); + + H = tanhf(H); + + gates[q] = H; + } + + float* output_data = top_blob.row(ti); + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = 0; q < num_output; q++) + { + float H = gates[q]; + + hidden_state[q] = H; + output_data[q] = H; + } + } + + return 0; +} +#endif // NCNN_INT8 + int RNN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { int T = bottom_blob.h; @@ -126,9 +259,20 @@ int RNN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const // Uni directional if (direction == 0 || direction == 1) { - int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -141,15 +285,37 @@ int RNN::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) 
const if (top_blob_reverse.empty()) return -100; - int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); + if (ret != 0) + return ret; + } hidden.fill(0.0f); - int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -194,9 +360,20 @@ int RNN::forward(const std::vector& bottom_blobs, std::vector& top_blo // Uni directional if (direction == 0 || direction == 1) { - int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); - if (ret != 0) - return ret; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob, direction, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob, direction, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden, opt); + if (ret != 0) + return ret; + } } if (direction == 2) @@ -210,14 +387,36 @@ int RNN::forward(const std::vector& bottom_blobs, std::vector& top_blo return -100; Mat hidden0 = hidden.row_range(0, 1); - int ret0 = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt); - if (ret0 != 0) - return ret0; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), weight_xc_data_int8_scales.row(0), bias_c_data.channel(0), weight_hc_data.channel(0), weight_hc_data_int8_scales.row(0), hidden0, opt); + if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_forward, 0, weight_xc_data.channel(0), bias_c_data.channel(0), weight_hc_data.channel(0), hidden0, opt); + if (ret != 0) + return ret; + } Mat hidden1 = hidden.row_range(1, 1); - int ret1 = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt); - if (ret1 != 0) - return ret1; +#if NCNN_INT8 + if (int8_scale_term) + { + int ret = rnn_int8(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), weight_xc_data_int8_scales.row(1), bias_c_data.channel(1), weight_hc_data.channel(1), weight_hc_data_int8_scales.row(1), hidden1, opt); + 
if (ret != 0) + return ret; + } + else +#endif + { + int ret = rnn(bottom_blob, top_blob_reverse, 1, weight_xc_data.channel(1), bias_c_data.channel(1), weight_hc_data.channel(1), hidden1, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) diff --git a/src/layer/rnn.h b/src/layer/rnn.h index c52a920811f..7a99ef51bbb 100644 --- a/src/layer/rnn.h +++ b/src/layer/rnn.h @@ -37,9 +37,16 @@ class RNN : public Layer int weight_data_size; int direction; // 0=forward 1=reverse 2=bidirectional + int int8_scale_term; + Mat weight_hc_data; Mat weight_xc_data; Mat bias_c_data; + +#if NCNN_INT8 + Mat weight_hc_data_int8_scales; + Mat weight_xc_data_int8_scales; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/lstm_int8.h b/src/layer/x86/lstm_int8.h new file mode 100644 index 00000000000..c5e8b06f259 --- /dev/null +++ b/src/layer/x86/lstm_int8.h @@ -0,0 +1,3163 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. + +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ +void lstm_transform_weight_int8_avx512vnni(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt); +void lstm_int8_avx512vnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ +void lstm_transform_weight_int8_avxvnni(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt); +void lstm_int8_avxvnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +void lstm_transform_weight_int8_avx2(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt); +void lstm_int8_avx2(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, 
Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt); +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ +void lstm_int8_xop(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt); +#endif + +static void lstm_transform_weight_int8(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + lstm_transform_weight_int8_avx512vnni(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, hidden_size, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + lstm_transform_weight_int8_avxvnni(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, hidden_size, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + lstm_transform_weight_int8_avx2(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, hidden_size, opt); + return; + } +#endif + +#if __AVX512F__ +#if __AVX512VNNI__ + weight_data_tm.create(size + 4 + num_output + 4, hidden_size / 4 + hidden_size % 4, num_directions, 16u, 16); +#else + weight_data_tm.create(size + num_output, hidden_size / 4 + hidden_size % 4, num_directions, 16u, 16); +#endif + weight_data_tm_int8_descales.create(16 + 16, hidden_size / 4 + hidden_size % 4, num_directions); +#elif __AVX2__ +#if __AVXVNNI__ + weight_data_tm.create(size + 4 + num_output + 4, hidden_size / 2 + hidden_size % 2, num_directions, 8u, 8); +#else + weight_data_tm.create(size + num_output, hidden_size / 2 + hidden_size % 2, num_directions, 8u, 8); +#endif + weight_data_tm_int8_descales.create(8 + 8, hidden_size / 2 + hidden_size % 2, num_directions); +#else + weight_data_tm.create(size + num_output, hidden_size, num_directions, 4u, 4); + weight_data_tm_int8_descales.create(4 + 4, hidden_size, num_directions); +#endif + bias_c_tm.create(hidden_size, 1, num_directions, 16u, 4); + + #pragma omp parallel for num_threads(opt.num_threads) + for (int dr = 0; dr < num_directions; dr++) + { + const Mat weight_xc_dr = weight_xc.channel(dr); + const Mat weight_hc_dr = weight_hc.channel(dr); + const Mat bias_c_dr = bias_c.channel(dr); + const float* weight_xc_int8_scales_ptr = weight_xc_int8_scales.row(dr); + const float* weight_hc_int8_scales_ptr = weight_hc_int8_scales.row(dr); + + Mat weight_data_tm_dr = weight_data_tm.channel(dr); + Mat bias_c_tm_dr = bias_c_tm.channel(dr); + Mat 
weight_data_tm_int8_descales_dr = weight_data_tm_int8_descales.channel(dr); + + const float* bias_c_I = bias_c_dr.row(0); + const float* bias_c_F = bias_c_dr.row(1); + const float* bias_c_O = bias_c_dr.row(2); + const float* bias_c_G = bias_c_dr.row(3); + + float* bias_c_IFOG = bias_c_tm_dr.row(0); + + int q = 0; +#if __AVX2__ +#if __AVX512F__ + for (; q + 3 < hidden_size; q += 4) + { + _mm_storeu_ps(bias_c_IFOG, _mm_loadu_ps(bias_c_I + q)); + _mm_storeu_ps(bias_c_IFOG + 4, _mm_loadu_ps(bias_c_F + q)); + _mm_storeu_ps(bias_c_IFOG + 8, _mm_loadu_ps(bias_c_O + q)); + _mm_storeu_ps(bias_c_IFOG + 12, _mm_loadu_ps(bias_c_G + q)); + bias_c_IFOG += 16; + + const signed char* weight_xc_I_0 = weight_xc_dr.row(hidden_size * 0 + q); + const signed char* weight_xc_F_0 = weight_xc_dr.row(hidden_size * 1 + q); + const signed char* weight_xc_O_0 = weight_xc_dr.row(hidden_size * 2 + q); + const signed char* weight_xc_G_0 = weight_xc_dr.row(hidden_size * 3 + q); + const signed char* weight_xc_I_1 = weight_xc_dr.row(hidden_size * 0 + q + 1); + const signed char* weight_xc_F_1 = weight_xc_dr.row(hidden_size * 1 + q + 1); + const signed char* weight_xc_O_1 = weight_xc_dr.row(hidden_size * 2 + q + 1); + const signed char* weight_xc_G_1 = weight_xc_dr.row(hidden_size * 3 + q + 1); + const signed char* weight_xc_I_2 = weight_xc_dr.row(hidden_size * 0 + q + 2); + const signed char* weight_xc_F_2 = weight_xc_dr.row(hidden_size * 1 + q + 2); + const signed char* weight_xc_O_2 = weight_xc_dr.row(hidden_size * 2 + q + 2); + const signed char* weight_xc_G_2 = weight_xc_dr.row(hidden_size * 3 + q + 2); + const signed char* weight_xc_I_3 = weight_xc_dr.row(hidden_size * 0 + q + 3); + const signed char* weight_xc_F_3 = weight_xc_dr.row(hidden_size * 1 + q + 3); + const signed char* weight_xc_O_3 = weight_xc_dr.row(hidden_size * 2 + q + 3); + const signed char* weight_xc_G_3 = weight_xc_dr.row(hidden_size * 3 + q + 3); + + const signed char* weight_hc_I_0 = weight_hc_dr.row(hidden_size * 0 + q); + const signed char* weight_hc_F_0 = weight_hc_dr.row(hidden_size * 1 + q); + const signed char* weight_hc_O_0 = weight_hc_dr.row(hidden_size * 2 + q); + const signed char* weight_hc_G_0 = weight_hc_dr.row(hidden_size * 3 + q); + const signed char* weight_hc_I_1 = weight_hc_dr.row(hidden_size * 0 + q + 1); + const signed char* weight_hc_F_1 = weight_hc_dr.row(hidden_size * 1 + q + 1); + const signed char* weight_hc_O_1 = weight_hc_dr.row(hidden_size * 2 + q + 1); + const signed char* weight_hc_G_1 = weight_hc_dr.row(hidden_size * 3 + q + 1); + const signed char* weight_hc_I_2 = weight_hc_dr.row(hidden_size * 0 + q + 2); + const signed char* weight_hc_F_2 = weight_hc_dr.row(hidden_size * 1 + q + 2); + const signed char* weight_hc_O_2 = weight_hc_dr.row(hidden_size * 2 + q + 2); + const signed char* weight_hc_G_2 = weight_hc_dr.row(hidden_size * 3 + q + 2); + const signed char* weight_hc_I_3 = weight_hc_dr.row(hidden_size * 0 + q + 3); + const signed char* weight_hc_F_3 = weight_hc_dr.row(hidden_size * 1 + q + 3); + const signed char* weight_hc_O_3 = weight_hc_dr.row(hidden_size * 2 + q + 3); + const signed char* weight_hc_G_3 = weight_hc_dr.row(hidden_size * 3 + q + 3); + + signed char* kptr = weight_data_tm_dr.row(q / 4); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q / 4); + + int i = 0; +#if __AVX512VNNI__ + __m512i _w_shift = _mm512_setzero_si512(); + __m512i _v127 = _mm512_set1_epi8(127); + + __m512i _w0_shift = _mm512_setzero_si512(); + __m512i _w1_shift = _mm512_setzero_si512(); +#if defined(__x86_64__) || 
defined(_M_X64) + __m512i _w2_shift = _mm512_setzero_si512(); + __m512i _w3_shift = _mm512_setzero_si512(); + for (; i + 15 < size; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_xc_I_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_xc_F_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 2), _mm_loadu_si128((const __m128i*)(weight_xc_O_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 3), _mm_loadu_si128((const __m128i*)(weight_xc_G_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 4), _mm_loadu_si128((const __m128i*)(weight_xc_I_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 5), _mm_loadu_si128((const __m128i*)(weight_xc_F_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 6), _mm_loadu_si128((const __m128i*)(weight_xc_O_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 7), _mm_loadu_si128((const __m128i*)(weight_xc_G_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 8), _mm_loadu_si128((const __m128i*)(weight_xc_I_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 9), _mm_loadu_si128((const __m128i*)(weight_xc_F_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 10), _mm_loadu_si128((const __m128i*)(weight_xc_O_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 11), _mm_loadu_si128((const __m128i*)(weight_xc_G_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 12), _mm_loadu_si128((const __m128i*)(weight_xc_I_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 13), _mm_loadu_si128((const __m128i*)(weight_xc_F_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 14), _mm_loadu_si128((const __m128i*)(weight_xc_O_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 15), _mm_loadu_si128((const __m128i*)(weight_xc_G_3 + i))); + + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + __m512i _w2 = _mm512_loadu_si512((const __m512i*)(kptr + 128)); + __m512i _w3 = _mm512_loadu_si512((const __m512i*)(kptr + 192)); + _w0_shift = _mm512_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm512_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm512_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm512_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 256; + } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_w0_shift, _w1_shift); + __m512i _tmp1 = _mm512_unpackhi_epi32(_w0_shift, _w1_shift); + __m512i _tmp2 = _mm512_unpacklo_epi32(_w2_shift, _w3_shift); + __m512i _tmp3 = _mm512_unpackhi_epi32(_w2_shift, _w3_shift); + _w0_shift = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _w1_shift = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _w2_shift = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _w3_shift = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _w_shift = _mm512_add_epi32(_w_shift, _w0_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w1_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w2_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w3_shift); + } + + _w0_shift = _mm512_setzero_si512(); + _w1_shift = _mm512_setzero_si512(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 2), _mm_loadl_epi64((const __m128i*)(weight_xc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 3), _mm_loadl_epi64((const __m128i*)(weight_xc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 4), _mm_loadl_epi64((const 
__m128i*)(weight_xc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 5), _mm_loadl_epi64((const __m128i*)(weight_xc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 6), _mm_loadl_epi64((const __m128i*)(weight_xc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 7), _mm_loadl_epi64((const __m128i*)(weight_xc_G_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 8), _mm_loadl_epi64((const __m128i*)(weight_xc_I_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 9), _mm_loadl_epi64((const __m128i*)(weight_xc_I_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 10), _mm_loadl_epi64((const __m128i*)(weight_xc_F_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 11), _mm_loadl_epi64((const __m128i*)(weight_xc_F_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 12), _mm_loadl_epi64((const __m128i*)(weight_xc_O_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 13), _mm_loadl_epi64((const __m128i*)(weight_xc_O_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 14), _mm_loadl_epi64((const __m128i*)(weight_xc_G_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 15), _mm_loadl_epi64((const __m128i*)(weight_xc_G_3 + i))); + + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + _w0_shift = _mm512_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm512_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 128; + } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_w0_shift), _mm512_castsi512_ps(_w1_shift), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_w0_shift), _mm512_castsi512_ps(_w1_shift), _MM_SHUFFLE(3, 1, 3, 1))); + + _w_shift = _mm512_add_epi32(_w_shift, _tmp0); + _w_shift = _mm512_add_epi32(_w_shift, _tmp1); + } + + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_I_0[i]; + kptr[1] = weight_xc_I_0[i + 1]; + kptr[2] = weight_xc_I_0[i + 2]; + kptr[3] = weight_xc_I_0[i + 3]; + kptr[4] = weight_xc_I_1[i]; + kptr[5] = weight_xc_I_1[i + 1]; + kptr[6] = weight_xc_I_1[i + 2]; + kptr[7] = weight_xc_I_1[i + 3]; + kptr[8 + 0] = weight_xc_I_2[i]; + kptr[8 + 1] = weight_xc_I_2[i + 1]; + kptr[8 + 2] = weight_xc_I_2[i + 2]; + kptr[8 + 3] = weight_xc_I_2[i + 3]; + kptr[8 + 4] = weight_xc_I_3[i]; + kptr[8 + 5] = weight_xc_I_3[i + 1]; + kptr[8 + 6] = weight_xc_I_3[i + 2]; + kptr[8 + 7] = weight_xc_I_3[i + 3]; + kptr[16 + 0] = weight_xc_F_0[i]; + kptr[16 + 1] = weight_xc_F_0[i + 1]; + kptr[16 + 2] = weight_xc_F_0[i + 2]; + kptr[16 + 3] = weight_xc_F_0[i + 3]; + kptr[16 + 4] = weight_xc_F_1[i]; + kptr[16 + 5] = weight_xc_F_1[i + 1]; + kptr[16 + 6] = weight_xc_F_1[i + 2]; + kptr[16 + 7] = weight_xc_F_1[i + 3]; + kptr[24 + 0] = weight_xc_F_2[i]; + kptr[24 + 1] = weight_xc_F_2[i + 1]; + kptr[24 + 2] = weight_xc_F_2[i + 2]; + kptr[24 + 3] = weight_xc_F_2[i + 3]; + kptr[24 + 4] = weight_xc_F_3[i]; + kptr[24 + 5] = weight_xc_F_3[i + 1]; + kptr[24 + 6] = weight_xc_F_3[i + 2]; + kptr[24 + 7] = weight_xc_F_3[i + 3]; + kptr[32 + 0] = weight_xc_O_0[i]; + kptr[32 + 1] = weight_xc_O_0[i + 1]; + kptr[32 + 2] = weight_xc_O_0[i + 2]; + kptr[32 + 3] = weight_xc_O_0[i + 3]; + kptr[32 + 4] = weight_xc_O_1[i]; + kptr[32 + 5] = weight_xc_O_1[i + 1]; + kptr[32 + 6] = weight_xc_O_1[i + 2]; + kptr[32 + 7] = weight_xc_O_1[i + 3]; + kptr[40 + 0] = weight_xc_O_2[i]; + kptr[40 + 1] = weight_xc_O_2[i + 1]; + kptr[40 + 2] = weight_xc_O_2[i + 2]; + kptr[40 + 3] = weight_xc_O_2[i + 3]; + kptr[40 + 4] = weight_xc_O_3[i]; + kptr[40 + 5] = weight_xc_O_3[i + 1]; + kptr[40 + 6] = 
weight_xc_O_3[i + 2]; + kptr[40 + 7] = weight_xc_O_3[i + 3]; + kptr[48 + 0] = weight_xc_G_0[i]; + kptr[48 + 1] = weight_xc_G_0[i + 1]; + kptr[48 + 2] = weight_xc_G_0[i + 2]; + kptr[48 + 3] = weight_xc_G_0[i + 3]; + kptr[48 + 4] = weight_xc_G_1[i]; + kptr[48 + 5] = weight_xc_G_1[i + 1]; + kptr[48 + 6] = weight_xc_G_1[i + 2]; + kptr[48 + 7] = weight_xc_G_1[i + 3]; + kptr[56 + 0] = weight_xc_G_2[i]; + kptr[56 + 1] = weight_xc_G_2[i + 1]; + kptr[56 + 2] = weight_xc_G_2[i + 2]; + kptr[56 + 3] = weight_xc_G_2[i + 3]; + kptr[56 + 4] = weight_xc_G_3[i]; + kptr[56 + 5] = weight_xc_G_3[i + 1]; + kptr[56 + 6] = weight_xc_G_3[i + 2]; + kptr[56 + 7] = weight_xc_G_3[i + 3]; + + __m512i _w = _mm512_loadu_si512((const __m512i*)kptr); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _w); + + kptr += 64; + } + + _mm512_storeu_si512((__m512i*)kptr, _w_shift); + kptr += 64; +#else +#if defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 2), _mm_loadl_epi64((const __m128i*)(weight_xc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 3), _mm_loadl_epi64((const __m128i*)(weight_xc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 4), _mm_loadl_epi64((const __m128i*)(weight_xc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 5), _mm_loadl_epi64((const __m128i*)(weight_xc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 6), _mm_loadl_epi64((const __m128i*)(weight_xc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 7), _mm_loadl_epi64((const __m128i*)(weight_xc_G_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 8), _mm_loadl_epi64((const __m128i*)(weight_xc_I_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 9), _mm_loadl_epi64((const __m128i*)(weight_xc_F_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 10), _mm_loadl_epi64((const __m128i*)(weight_xc_O_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 11), _mm_loadl_epi64((const __m128i*)(weight_xc_G_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 12), _mm_loadl_epi64((const __m128i*)(weight_xc_I_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 13), _mm_loadl_epi64((const __m128i*)(weight_xc_F_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 14), _mm_loadl_epi64((const __m128i*)(weight_xc_O_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 15), _mm_loadl_epi64((const __m128i*)(weight_xc_G_3 + i))); + kptr += 128; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_I_0[i]; + kptr[1] = weight_xc_I_0[i + 1]; + kptr[2] = weight_xc_I_0[i + 2]; + kptr[3] = weight_xc_I_0[i + 3]; + kptr[4] = weight_xc_I_1[i]; + kptr[5] = weight_xc_I_1[i + 1]; + kptr[6] = weight_xc_I_1[i + 2]; + kptr[7] = weight_xc_I_1[i + 3]; + kptr[8 + 0] = weight_xc_F_0[i]; + kptr[8 + 1] = weight_xc_F_0[i + 1]; + kptr[8 + 2] = weight_xc_F_0[i + 2]; + kptr[8 + 3] = weight_xc_F_0[i + 3]; + kptr[8 + 4] = weight_xc_F_1[i]; + kptr[8 + 5] = weight_xc_F_1[i + 1]; + kptr[8 + 6] = weight_xc_F_1[i + 2]; + kptr[8 + 7] = weight_xc_F_1[i + 3]; + kptr[16 + 0] = weight_xc_O_0[i]; + kptr[16 + 1] = weight_xc_O_0[i + 1]; + kptr[16 + 2] = weight_xc_O_0[i + 2]; + kptr[16 + 3] = weight_xc_O_0[i + 3]; + kptr[16 + 4] = weight_xc_O_1[i]; + kptr[16 + 5] = weight_xc_O_1[i + 1]; + kptr[16 + 6] = weight_xc_O_1[i + 2]; + kptr[16 + 7] = weight_xc_O_1[i + 3]; + kptr[24 + 0] = weight_xc_G_0[i]; + 
kptr[24 + 1] = weight_xc_G_0[i + 1]; + kptr[24 + 2] = weight_xc_G_0[i + 2]; + kptr[24 + 3] = weight_xc_G_0[i + 3]; + kptr[24 + 4] = weight_xc_G_1[i]; + kptr[24 + 5] = weight_xc_G_1[i + 1]; + kptr[24 + 6] = weight_xc_G_1[i + 2]; + kptr[24 + 7] = weight_xc_G_1[i + 3]; + kptr[32 + 0] = weight_xc_I_2[i]; + kptr[32 + 1] = weight_xc_I_2[i + 1]; + kptr[32 + 2] = weight_xc_I_2[i + 2]; + kptr[32 + 3] = weight_xc_I_2[i + 3]; + kptr[32 + 4] = weight_xc_I_3[i]; + kptr[32 + 5] = weight_xc_I_3[i + 1]; + kptr[32 + 6] = weight_xc_I_3[i + 2]; + kptr[32 + 7] = weight_xc_I_3[i + 3]; + kptr[40 + 0] = weight_xc_F_2[i]; + kptr[40 + 1] = weight_xc_F_2[i + 1]; + kptr[40 + 2] = weight_xc_F_2[i + 2]; + kptr[40 + 3] = weight_xc_F_2[i + 3]; + kptr[40 + 4] = weight_xc_F_3[i]; + kptr[40 + 5] = weight_xc_F_3[i + 1]; + kptr[40 + 6] = weight_xc_F_3[i + 2]; + kptr[40 + 7] = weight_xc_F_3[i + 3]; + kptr[48 + 0] = weight_xc_O_2[i]; + kptr[48 + 1] = weight_xc_O_2[i + 1]; + kptr[48 + 2] = weight_xc_O_2[i + 2]; + kptr[48 + 3] = weight_xc_O_2[i + 3]; + kptr[48 + 4] = weight_xc_O_3[i]; + kptr[48 + 5] = weight_xc_O_3[i + 1]; + kptr[48 + 6] = weight_xc_O_3[i + 2]; + kptr[48 + 7] = weight_xc_O_3[i + 3]; + kptr[56 + 0] = weight_xc_G_2[i]; + kptr[56 + 1] = weight_xc_G_2[i + 1]; + kptr[56 + 2] = weight_xc_G_2[i + 2]; + kptr[56 + 3] = weight_xc_G_2[i + 3]; + kptr[56 + 4] = weight_xc_G_3[i]; + kptr[56 + 5] = weight_xc_G_3[i + 1]; + kptr[56 + 6] = weight_xc_G_3[i + 2]; + kptr[56 + 7] = weight_xc_G_3[i + 3]; + kptr += 64; + } +#endif // __AVX512VNNI__ + for (; i + 1 < size; i += 2) + { + kptr[0] = weight_xc_I_0[i]; + kptr[1] = weight_xc_I_0[i + 1]; + kptr[2] = weight_xc_I_1[i]; + kptr[3] = weight_xc_I_1[i + 1]; + kptr[4] = weight_xc_I_2[i]; + kptr[5] = weight_xc_I_2[i + 1]; + kptr[6] = weight_xc_I_3[i]; + kptr[7] = weight_xc_I_3[i + 1]; + kptr[8 + 0] = weight_xc_F_0[i]; + kptr[8 + 1] = weight_xc_F_0[i + 1]; + kptr[8 + 2] = weight_xc_F_1[i]; + kptr[8 + 3] = weight_xc_F_1[i + 1]; + kptr[8 + 4] = weight_xc_F_2[i]; + kptr[8 + 5] = weight_xc_F_2[i + 1]; + kptr[8 + 6] = weight_xc_F_3[i]; + kptr[8 + 7] = weight_xc_F_3[i + 1]; + kptr[16 + 0] = weight_xc_O_0[i]; + kptr[16 + 1] = weight_xc_O_0[i + 1]; + kptr[16 + 2] = weight_xc_O_1[i]; + kptr[16 + 3] = weight_xc_O_1[i + 1]; + kptr[16 + 4] = weight_xc_O_2[i]; + kptr[16 + 5] = weight_xc_O_2[i + 1]; + kptr[16 + 6] = weight_xc_O_3[i]; + kptr[16 + 7] = weight_xc_O_3[i + 1]; + kptr[24 + 0] = weight_xc_G_0[i]; + kptr[24 + 1] = weight_xc_G_0[i + 1]; + kptr[24 + 2] = weight_xc_G_1[i]; + kptr[24 + 3] = weight_xc_G_1[i + 1]; + kptr[24 + 4] = weight_xc_G_2[i]; + kptr[24 + 5] = weight_xc_G_2[i + 1]; + kptr[24 + 6] = weight_xc_G_3[i]; + kptr[24 + 7] = weight_xc_G_3[i + 1]; + kptr += 32; + } + for (; i < size; i++) + { + kptr[0] = weight_xc_I_0[i]; + kptr[1] = weight_xc_I_1[i]; + kptr[2] = weight_xc_I_2[i]; + kptr[3] = weight_xc_I_3[i]; + kptr[4] = weight_xc_F_0[i]; + kptr[5] = weight_xc_F_1[i]; + kptr[6] = weight_xc_F_2[i]; + kptr[7] = weight_xc_F_3[i]; + kptr[8 + 0] = weight_xc_O_0[i]; + kptr[8 + 1] = weight_xc_O_1[i]; + kptr[8 + 2] = weight_xc_O_2[i]; + kptr[8 + 3] = weight_xc_O_3[i]; + kptr[8 + 4] = weight_xc_G_0[i]; + kptr[8 + 5] = weight_xc_G_1[i]; + kptr[8 + 6] = weight_xc_G_2[i]; + kptr[8 + 7] = weight_xc_G_3[i]; + kptr += 16; + } + + i = 0; +#if __AVX512VNNI__ + _w_shift = _mm512_setzero_si512(); + _w0_shift = _mm512_setzero_si512(); + _w1_shift = _mm512_setzero_si512(); +#if defined(__x86_64__) || defined(_M_X64) + _w2_shift = _mm512_setzero_si512(); + _w3_shift = _mm512_setzero_si512(); + for (; i + 15 
< num_output; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_hc_I_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_hc_F_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 2), _mm_loadu_si128((const __m128i*)(weight_hc_O_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 3), _mm_loadu_si128((const __m128i*)(weight_hc_G_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 4), _mm_loadu_si128((const __m128i*)(weight_hc_I_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 5), _mm_loadu_si128((const __m128i*)(weight_hc_F_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 6), _mm_loadu_si128((const __m128i*)(weight_hc_O_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 7), _mm_loadu_si128((const __m128i*)(weight_hc_G_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 8), _mm_loadu_si128((const __m128i*)(weight_hc_I_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 9), _mm_loadu_si128((const __m128i*)(weight_hc_F_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 10), _mm_loadu_si128((const __m128i*)(weight_hc_O_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 11), _mm_loadu_si128((const __m128i*)(weight_hc_G_2 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 12), _mm_loadu_si128((const __m128i*)(weight_hc_I_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 13), _mm_loadu_si128((const __m128i*)(weight_hc_F_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 14), _mm_loadu_si128((const __m128i*)(weight_hc_O_3 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 15), _mm_loadu_si128((const __m128i*)(weight_hc_G_3 + i))); + + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + __m512i _w2 = _mm512_loadu_si512((const __m512i*)(kptr + 128)); + __m512i _w3 = _mm512_loadu_si512((const __m512i*)(kptr + 192)); + _w0_shift = _mm512_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm512_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm512_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm512_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 256; + } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_w0_shift, _w1_shift); + __m512i _tmp1 = _mm512_unpackhi_epi32(_w0_shift, _w1_shift); + __m512i _tmp2 = _mm512_unpacklo_epi32(_w2_shift, _w3_shift); + __m512i _tmp3 = _mm512_unpackhi_epi32(_w2_shift, _w3_shift); + _w0_shift = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _w1_shift = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _w2_shift = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _w3_shift = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _w_shift = _mm512_add_epi32(_w_shift, _w0_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w1_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w2_shift); + _w_shift = _mm512_add_epi32(_w_shift, _w3_shift); + } + + _w0_shift = _mm512_setzero_si512(); + _w1_shift = _mm512_setzero_si512(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 2), _mm_loadl_epi64((const __m128i*)(weight_hc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 3), _mm_loadl_epi64((const __m128i*)(weight_hc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 4), _mm_loadl_epi64((const __m128i*)(weight_hc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 5), _mm_loadl_epi64((const 
__m128i*)(weight_hc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 6), _mm_loadl_epi64((const __m128i*)(weight_hc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 7), _mm_loadl_epi64((const __m128i*)(weight_hc_G_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 8), _mm_loadl_epi64((const __m128i*)(weight_hc_I_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 9), _mm_loadl_epi64((const __m128i*)(weight_hc_I_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 10), _mm_loadl_epi64((const __m128i*)(weight_hc_F_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 11), _mm_loadl_epi64((const __m128i*)(weight_hc_F_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 12), _mm_loadl_epi64((const __m128i*)(weight_hc_O_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 13), _mm_loadl_epi64((const __m128i*)(weight_hc_O_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 14), _mm_loadl_epi64((const __m128i*)(weight_hc_G_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 15), _mm_loadl_epi64((const __m128i*)(weight_hc_G_3 + i))); + + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + _w0_shift = _mm512_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm512_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 128; + } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_w0_shift), _mm512_castsi512_ps(_w1_shift), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_w0_shift), _mm512_castsi512_ps(_w1_shift), _MM_SHUFFLE(3, 1, 3, 1))); + + _w_shift = _mm512_add_epi32(_w_shift, _tmp0); + _w_shift = _mm512_add_epi32(_w_shift, _tmp1); + } + + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_I_0[i]; + kptr[1] = weight_hc_I_0[i + 1]; + kptr[2] = weight_hc_I_0[i + 2]; + kptr[3] = weight_hc_I_0[i + 3]; + kptr[4] = weight_hc_I_1[i]; + kptr[5] = weight_hc_I_1[i + 1]; + kptr[6] = weight_hc_I_1[i + 2]; + kptr[7] = weight_hc_I_1[i + 3]; + kptr[8 + 0] = weight_hc_I_2[i]; + kptr[8 + 1] = weight_hc_I_2[i + 1]; + kptr[8 + 2] = weight_hc_I_2[i + 2]; + kptr[8 + 3] = weight_hc_I_2[i + 3]; + kptr[8 + 4] = weight_hc_I_3[i]; + kptr[8 + 5] = weight_hc_I_3[i + 1]; + kptr[8 + 6] = weight_hc_I_3[i + 2]; + kptr[8 + 7] = weight_hc_I_3[i + 3]; + kptr[16 + 0] = weight_hc_F_0[i]; + kptr[16 + 1] = weight_hc_F_0[i + 1]; + kptr[16 + 2] = weight_hc_F_0[i + 2]; + kptr[16 + 3] = weight_hc_F_0[i + 3]; + kptr[16 + 4] = weight_hc_F_1[i]; + kptr[16 + 5] = weight_hc_F_1[i + 1]; + kptr[16 + 6] = weight_hc_F_1[i + 2]; + kptr[16 + 7] = weight_hc_F_1[i + 3]; + kptr[24 + 0] = weight_hc_F_2[i]; + kptr[24 + 1] = weight_hc_F_2[i + 1]; + kptr[24 + 2] = weight_hc_F_2[i + 2]; + kptr[24 + 3] = weight_hc_F_2[i + 3]; + kptr[24 + 4] = weight_hc_F_3[i]; + kptr[24 + 5] = weight_hc_F_3[i + 1]; + kptr[24 + 6] = weight_hc_F_3[i + 2]; + kptr[24 + 7] = weight_hc_F_3[i + 3]; + kptr[32 + 0] = weight_hc_O_0[i]; + kptr[32 + 1] = weight_hc_O_0[i + 1]; + kptr[32 + 2] = weight_hc_O_0[i + 2]; + kptr[32 + 3] = weight_hc_O_0[i + 3]; + kptr[32 + 4] = weight_hc_O_1[i]; + kptr[32 + 5] = weight_hc_O_1[i + 1]; + kptr[32 + 6] = weight_hc_O_1[i + 2]; + kptr[32 + 7] = weight_hc_O_1[i + 3]; + kptr[40 + 0] = weight_hc_O_2[i]; + kptr[40 + 1] = weight_hc_O_2[i + 1]; + kptr[40 + 2] = weight_hc_O_2[i + 2]; + kptr[40 + 3] = weight_hc_O_2[i + 3]; + kptr[40 + 4] = weight_hc_O_3[i]; + kptr[40 + 5] = weight_hc_O_3[i + 1]; + kptr[40 + 6] = weight_hc_O_3[i + 2]; + kptr[40 + 7] = weight_hc_O_3[i + 3]; + kptr[48 + 0] = weight_hc_G_0[i]; 
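// A minimal scalar sketch of the packing idea used in the loops above: the I/F/O/G
// weight rows of a block of hidden units are interleaved so one contiguous load feeds
// all four gates of that block. The helper name and the single-unit layout below are
// assumptions for illustration; the real packing above also emits the wider 4-unit
// and 2-unit blocks and the VNNI shift terms.
#include <vector>

static std::vector<signed char> pack_lstm_gate_block(const signed char* wI,
                                                     const signed char* wF,
                                                     const signed char* wO,
                                                     const signed char* wG,
                                                     int n)
{
    std::vector<signed char> packed(n * 4);
    signed char* kptr = packed.data();
    int i = 0;
    for (; i + 3 < n; i += 4)
    {
        // one 16-byte group: [I i..i+3][F i..i+3][O i..i+3][G i..i+3]
        for (int k = 0; k < 4; k++) kptr[k] = wI[i + k];
        for (int k = 0; k < 4; k++) kptr[4 + k] = wF[i + k];
        for (int k = 0; k < 4; k++) kptr[8 + k] = wO[i + k];
        for (int k = 0; k < 4; k++) kptr[12 + k] = wG[i + k];
        kptr += 16;
    }
    for (; i < n; i++)
    {
        // leftover elements: one weight per gate
        kptr[0] = wI[i];
        kptr[1] = wF[i];
        kptr[2] = wO[i];
        kptr[3] = wG[i];
        kptr += 4;
    }
    return packed;
}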
+ kptr[48 + 1] = weight_hc_G_0[i + 1]; + kptr[48 + 2] = weight_hc_G_0[i + 2]; + kptr[48 + 3] = weight_hc_G_0[i + 3]; + kptr[48 + 4] = weight_hc_G_1[i]; + kptr[48 + 5] = weight_hc_G_1[i + 1]; + kptr[48 + 6] = weight_hc_G_1[i + 2]; + kptr[48 + 7] = weight_hc_G_1[i + 3]; + kptr[56 + 0] = weight_hc_G_2[i]; + kptr[56 + 1] = weight_hc_G_2[i + 1]; + kptr[56 + 2] = weight_hc_G_2[i + 2]; + kptr[56 + 3] = weight_hc_G_2[i + 3]; + kptr[56 + 4] = weight_hc_G_3[i]; + kptr[56 + 5] = weight_hc_G_3[i + 1]; + kptr[56 + 6] = weight_hc_G_3[i + 2]; + kptr[56 + 7] = weight_hc_G_3[i + 3]; + + __m512i _w = _mm512_loadu_si512((const __m512i*)kptr); + _w_shift = _mm512_dpbusd_epi32(_w_shift, _v127, _w); + + kptr += 64; + } + + _mm512_storeu_si512((__m512i*)kptr, _w_shift); + kptr += 64; +#else +#if defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 2), _mm_loadl_epi64((const __m128i*)(weight_hc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 3), _mm_loadl_epi64((const __m128i*)(weight_hc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 4), _mm_loadl_epi64((const __m128i*)(weight_hc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 5), _mm_loadl_epi64((const __m128i*)(weight_hc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 6), _mm_loadl_epi64((const __m128i*)(weight_hc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 7), _mm_loadl_epi64((const __m128i*)(weight_hc_G_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 8), _mm_loadl_epi64((const __m128i*)(weight_hc_I_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 9), _mm_loadl_epi64((const __m128i*)(weight_hc_F_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 10), _mm_loadl_epi64((const __m128i*)(weight_hc_O_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 11), _mm_loadl_epi64((const __m128i*)(weight_hc_G_2 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 12), _mm_loadl_epi64((const __m128i*)(weight_hc_I_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 13), _mm_loadl_epi64((const __m128i*)(weight_hc_F_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 14), _mm_loadl_epi64((const __m128i*)(weight_hc_O_3 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8 * 15), _mm_loadl_epi64((const __m128i*)(weight_hc_G_3 + i))); + kptr += 128; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_I_0[i]; + kptr[1] = weight_hc_I_0[i + 1]; + kptr[2] = weight_hc_I_0[i + 2]; + kptr[3] = weight_hc_I_0[i + 3]; + kptr[4] = weight_hc_I_1[i]; + kptr[5] = weight_hc_I_1[i + 1]; + kptr[6] = weight_hc_I_1[i + 2]; + kptr[7] = weight_hc_I_1[i + 3]; + kptr[8 + 0] = weight_hc_F_0[i]; + kptr[8 + 1] = weight_hc_F_0[i + 1]; + kptr[8 + 2] = weight_hc_F_0[i + 2]; + kptr[8 + 3] = weight_hc_F_0[i + 3]; + kptr[8 + 4] = weight_hc_F_1[i]; + kptr[8 + 5] = weight_hc_F_1[i + 1]; + kptr[8 + 6] = weight_hc_F_1[i + 2]; + kptr[8 + 7] = weight_hc_F_1[i + 3]; + kptr[16 + 0] = weight_hc_O_0[i]; + kptr[16 + 1] = weight_hc_O_0[i + 1]; + kptr[16 + 2] = weight_hc_O_0[i + 2]; + kptr[16 + 3] = weight_hc_O_0[i + 3]; + kptr[16 + 4] = weight_hc_O_1[i]; + kptr[16 + 5] = weight_hc_O_1[i + 1]; + kptr[16 + 6] = weight_hc_O_1[i + 2]; + kptr[16 + 7] = weight_hc_O_1[i + 3]; + kptr[24 + 0] = weight_hc_G_0[i]; + kptr[24 + 1] = weight_hc_G_0[i + 1]; + kptr[24 + 2] = weight_hc_G_0[i + 2]; + kptr[24 + 
3] = weight_hc_G_0[i + 3]; + kptr[24 + 4] = weight_hc_G_1[i]; + kptr[24 + 5] = weight_hc_G_1[i + 1]; + kptr[24 + 6] = weight_hc_G_1[i + 2]; + kptr[24 + 7] = weight_hc_G_1[i + 3]; + kptr[32 + 0] = weight_hc_I_2[i]; + kptr[32 + 1] = weight_hc_I_2[i + 1]; + kptr[32 + 2] = weight_hc_I_2[i + 2]; + kptr[32 + 3] = weight_hc_I_2[i + 3]; + kptr[32 + 4] = weight_hc_I_3[i]; + kptr[32 + 5] = weight_hc_I_3[i + 1]; + kptr[32 + 6] = weight_hc_I_3[i + 2]; + kptr[32 + 7] = weight_hc_I_3[i + 3]; + kptr[40 + 0] = weight_hc_F_2[i]; + kptr[40 + 1] = weight_hc_F_2[i + 1]; + kptr[40 + 2] = weight_hc_F_2[i + 2]; + kptr[40 + 3] = weight_hc_F_2[i + 3]; + kptr[40 + 4] = weight_hc_F_3[i]; + kptr[40 + 5] = weight_hc_F_3[i + 1]; + kptr[40 + 6] = weight_hc_F_3[i + 2]; + kptr[40 + 7] = weight_hc_F_3[i + 3]; + kptr[48 + 0] = weight_hc_O_2[i]; + kptr[48 + 1] = weight_hc_O_2[i + 1]; + kptr[48 + 2] = weight_hc_O_2[i + 2]; + kptr[48 + 3] = weight_hc_O_2[i + 3]; + kptr[48 + 4] = weight_hc_O_3[i]; + kptr[48 + 5] = weight_hc_O_3[i + 1]; + kptr[48 + 6] = weight_hc_O_3[i + 2]; + kptr[48 + 7] = weight_hc_O_3[i + 3]; + kptr[56 + 0] = weight_hc_G_2[i]; + kptr[56 + 1] = weight_hc_G_2[i + 1]; + kptr[56 + 2] = weight_hc_G_2[i + 2]; + kptr[56 + 3] = weight_hc_G_2[i + 3]; + kptr[56 + 4] = weight_hc_G_3[i]; + kptr[56 + 5] = weight_hc_G_3[i + 1]; + kptr[56 + 6] = weight_hc_G_3[i + 2]; + kptr[56 + 7] = weight_hc_G_3[i + 3]; + kptr += 64; + } +#endif // __AVX512VNNI__ + for (; i + 1 < num_output; i += 2) + { + kptr[0] = weight_hc_I_0[i]; + kptr[1] = weight_hc_I_0[i + 1]; + kptr[2] = weight_hc_I_1[i]; + kptr[3] = weight_hc_I_1[i + 1]; + kptr[4] = weight_hc_I_2[i]; + kptr[5] = weight_hc_I_2[i + 1]; + kptr[6] = weight_hc_I_3[i]; + kptr[7] = weight_hc_I_3[i + 1]; + kptr[8 + 0] = weight_hc_F_0[i]; + kptr[8 + 1] = weight_hc_F_0[i + 1]; + kptr[8 + 2] = weight_hc_F_1[i]; + kptr[8 + 3] = weight_hc_F_1[i + 1]; + kptr[8 + 4] = weight_hc_F_2[i]; + kptr[8 + 5] = weight_hc_F_2[i + 1]; + kptr[8 + 6] = weight_hc_F_3[i]; + kptr[8 + 7] = weight_hc_F_3[i + 1]; + kptr[16 + 0] = weight_hc_O_0[i]; + kptr[16 + 1] = weight_hc_O_0[i + 1]; + kptr[16 + 2] = weight_hc_O_1[i]; + kptr[16 + 3] = weight_hc_O_1[i + 1]; + kptr[16 + 4] = weight_hc_O_2[i]; + kptr[16 + 5] = weight_hc_O_2[i + 1]; + kptr[16 + 6] = weight_hc_O_3[i]; + kptr[16 + 7] = weight_hc_O_3[i + 1]; + kptr[24 + 0] = weight_hc_G_0[i]; + kptr[24 + 1] = weight_hc_G_0[i + 1]; + kptr[24 + 2] = weight_hc_G_1[i]; + kptr[24 + 3] = weight_hc_G_1[i + 1]; + kptr[24 + 4] = weight_hc_G_2[i]; + kptr[24 + 5] = weight_hc_G_2[i + 1]; + kptr[24 + 6] = weight_hc_G_3[i]; + kptr[24 + 7] = weight_hc_G_3[i + 1]; + kptr += 32; + } + for (; i < num_output; i++) + { + kptr[0] = weight_hc_I_0[i]; + kptr[1] = weight_hc_I_1[i]; + kptr[2] = weight_hc_I_2[i]; + kptr[3] = weight_hc_I_3[i]; + kptr[4] = weight_hc_F_0[i]; + kptr[5] = weight_hc_F_1[i]; + kptr[6] = weight_hc_F_2[i]; + kptr[7] = weight_hc_F_3[i]; + kptr[8 + 0] = weight_hc_O_0[i]; + kptr[8 + 1] = weight_hc_O_1[i]; + kptr[8 + 2] = weight_hc_O_2[i]; + kptr[8 + 3] = weight_hc_O_3[i]; + kptr[8 + 4] = weight_hc_G_0[i]; + kptr[8 + 5] = weight_hc_G_1[i]; + kptr[8 + 6] = weight_hc_G_2[i]; + kptr[8 + 7] = weight_hc_G_3[i]; + kptr += 16; + } + + _mm_storeu_ps(bias_c_IFOG, _mm_loadu_ps(bias_c_I + q)); + + __m128 _descale_xc_I = _mm_loadu_ps(weight_xc_int8_scales_ptr + hidden_size * 0 + q); + __m128 _descale_xc_F = _mm_loadu_ps(weight_xc_int8_scales_ptr + hidden_size * 1 + q); + __m128 _descale_xc_O = _mm_loadu_ps(weight_xc_int8_scales_ptr + hidden_size * 2 + q); + __m128 _descale_xc_G = 
_mm_loadu_ps(weight_xc_int8_scales_ptr + hidden_size * 3 + q); + __m128 _descale_hc_I = _mm_loadu_ps(weight_hc_int8_scales_ptr + hidden_size * 0 + q); + __m128 _descale_hc_F = _mm_loadu_ps(weight_hc_int8_scales_ptr + hidden_size * 1 + q); + __m128 _descale_hc_O = _mm_loadu_ps(weight_hc_int8_scales_ptr + hidden_size * 2 + q); + __m128 _descale_hc_G = _mm_loadu_ps(weight_hc_int8_scales_ptr + hidden_size * 3 + q); + + __m512 _descale_xc_IFOG = _mm512_castps128_ps512(_descale_xc_I); + _descale_xc_IFOG = _mm512_insertf32x4(_descale_xc_IFOG, _descale_xc_F, 1); + _descale_xc_IFOG = _mm512_insertf32x4(_descale_xc_IFOG, _descale_xc_O, 2); + _descale_xc_IFOG = _mm512_insertf32x4(_descale_xc_IFOG, _descale_xc_G, 3); + __m512 _descale_hc_IFOG = _mm512_castps128_ps512(_descale_hc_I); + _descale_hc_IFOG = _mm512_insertf32x4(_descale_hc_IFOG, _descale_hc_F, 1); + _descale_hc_IFOG = _mm512_insertf32x4(_descale_hc_IFOG, _descale_hc_O, 2); + _descale_hc_IFOG = _mm512_insertf32x4(_descale_hc_IFOG, _descale_hc_G, 3); + + _descale_xc_IFOG = _mm512_div_ps(_mm512_set1_ps(1.f), _descale_xc_IFOG); + _descale_hc_IFOG = _mm512_div_ps(_mm512_set1_ps(1.f), _descale_hc_IFOG); + + _mm512_storeu_ps(descales_ptr, _descale_xc_IFOG); + _mm512_storeu_ps(descales_ptr + 16, _descale_hc_IFOG); + } +#endif // __AVX512F__ + for (; q + 1 < hidden_size; q += 2) + { + bias_c_IFOG[0] = bias_c_I[q]; + bias_c_IFOG[1] = bias_c_F[q]; + bias_c_IFOG[2] = bias_c_O[q]; + bias_c_IFOG[3] = bias_c_G[q]; + bias_c_IFOG[4] = bias_c_I[q + 1]; + bias_c_IFOG[5] = bias_c_F[q + 1]; + bias_c_IFOG[6] = bias_c_O[q + 1]; + bias_c_IFOG[7] = bias_c_G[q + 1]; + + bias_c_IFOG += 8; + + const signed char* weight_xc_I_0 = weight_xc_dr.row(hidden_size * 0 + q); + const signed char* weight_xc_F_0 = weight_xc_dr.row(hidden_size * 1 + q); + const signed char* weight_xc_O_0 = weight_xc_dr.row(hidden_size * 2 + q); + const signed char* weight_xc_G_0 = weight_xc_dr.row(hidden_size * 3 + q); + const signed char* weight_xc_I_1 = weight_xc_dr.row(hidden_size * 0 + q + 1); + const signed char* weight_xc_F_1 = weight_xc_dr.row(hidden_size * 1 + q + 1); + const signed char* weight_xc_O_1 = weight_xc_dr.row(hidden_size * 2 + q + 1); + const signed char* weight_xc_G_1 = weight_xc_dr.row(hidden_size * 3 + q + 1); + + const signed char* weight_hc_I_0 = weight_hc_dr.row(hidden_size * 0 + q); + const signed char* weight_hc_F_0 = weight_hc_dr.row(hidden_size * 1 + q); + const signed char* weight_hc_O_0 = weight_hc_dr.row(hidden_size * 2 + q); + const signed char* weight_hc_G_0 = weight_hc_dr.row(hidden_size * 3 + q); + const signed char* weight_hc_I_1 = weight_hc_dr.row(hidden_size * 0 + q + 1); + const signed char* weight_hc_F_1 = weight_hc_dr.row(hidden_size * 1 + q + 1); + const signed char* weight_hc_O_1 = weight_hc_dr.row(hidden_size * 2 + q + 1); + const signed char* weight_hc_G_1 = weight_hc_dr.row(hidden_size * 3 + q + 1); + +#if __AVX512F__ + signed char* kptr = weight_data_tm_dr.row(q / 4 + (q % 4) / 2); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q / 4 + (q % 4) / 2); +#else + signed char* kptr = weight_data_tm_dr.row(q / 2); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q / 2); +#endif + + int i = 0; +#if __AVXVNNI__ || __AVX512VNNI__ + __m256i _w_shift = _mm256_setzero_si256(); + __m256i _v127 = _mm256_set1_epi8(127); + + __m256i _w0_shift = _mm256_setzero_si256(); + __m256i _w1_shift = _mm256_setzero_si256(); +#if defined(__x86_64__) || defined(_M_X64) + __m256i _w2_shift = _mm256_setzero_si256(); + __m256i _w3_shift = 
_mm256_setzero_si256(); + for (; i + 15 < size; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_xc_I_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_xc_I_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 2), _mm_loadu_si128((const __m128i*)(weight_xc_F_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 3), _mm_loadu_si128((const __m128i*)(weight_xc_F_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 4), _mm_loadu_si128((const __m128i*)(weight_xc_O_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 5), _mm_loadu_si128((const __m128i*)(weight_xc_O_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 6), _mm_loadu_si128((const __m128i*)(weight_xc_G_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 7), _mm_loadu_si128((const __m128i*)(weight_xc_G_1 + i))); + + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); + __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); + _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm256_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm256_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 128; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_w0_shift, _w1_shift); + __m256i _tmp1 = _mm256_hadd_epi32(_w2_shift, _w3_shift); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _w_shift = _mm256_add_epi32(_w_shift, _tmp0); + } + + _w0_shift = _mm256_setzero_si256(); + _w1_shift = _mm256_setzero_si256(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_xc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_xc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 32), _mm_loadl_epi64((const __m128i*)(weight_xc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 40), _mm_loadl_epi64((const __m128i*)(weight_xc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 48), _mm_loadl_epi64((const __m128i*)(weight_xc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 56), _mm_loadl_epi64((const __m128i*)(weight_xc_G_1 + i))); + + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 64; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_w0_shift, _w1_shift); + _w_shift = _mm256_add_epi32(_w_shift, _tmp0); + } + + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_I_0[i]; + kptr[1] = weight_xc_I_0[i + 1]; + kptr[2] = weight_xc_I_0[i + 2]; + kptr[3] = weight_xc_I_0[i + 3]; + kptr[4] = weight_xc_F_0[i]; + kptr[5] = weight_xc_F_0[i + 1]; + kptr[6] = weight_xc_F_0[i + 2]; + kptr[7] = weight_xc_F_0[i + 3]; + kptr[8 + 0] = weight_xc_O_0[i]; + kptr[8 + 1] = weight_xc_O_0[i + 1]; + kptr[8 + 2] = weight_xc_O_0[i + 2]; + kptr[8 + 3] = weight_xc_O_0[i + 3]; + kptr[8 + 4] = weight_xc_G_0[i]; + kptr[8 + 5] = weight_xc_G_0[i + 1]; + kptr[8 + 6] = weight_xc_G_0[i + 2]; + kptr[8 + 7] = weight_xc_G_0[i + 3]; + kptr[16 + 0] = weight_xc_I_1[i]; + kptr[16 + 1] = weight_xc_I_1[i + 
1]; + kptr[16 + 2] = weight_xc_I_1[i + 2]; + kptr[16 + 3] = weight_xc_I_1[i + 3]; + kptr[16 + 4] = weight_xc_F_1[i]; + kptr[16 + 5] = weight_xc_F_1[i + 1]; + kptr[16 + 6] = weight_xc_F_1[i + 2]; + kptr[16 + 7] = weight_xc_F_1[i + 3]; + kptr[24 + 0] = weight_xc_O_1[i]; + kptr[24 + 1] = weight_xc_O_1[i + 1]; + kptr[24 + 2] = weight_xc_O_1[i + 2]; + kptr[24 + 3] = weight_xc_O_1[i + 3]; + kptr[24 + 4] = weight_xc_G_1[i]; + kptr[24 + 5] = weight_xc_G_1[i + 1]; + kptr[24 + 6] = weight_xc_G_1[i + 2]; + kptr[24 + 7] = weight_xc_G_1[i + 3]; + + __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); + _w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _w); + + kptr += 32; + } + + _mm256_storeu_si256((__m256i*)kptr, _w_shift); + kptr += 32; +#else +#if defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_xc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_xc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 32), _mm_loadl_epi64((const __m128i*)(weight_xc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 40), _mm_loadl_epi64((const __m128i*)(weight_xc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 48), _mm_loadl_epi64((const __m128i*)(weight_xc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 56), _mm_loadl_epi64((const __m128i*)(weight_xc_G_1 + i))); + kptr += 64; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_I_0[i]; + kptr[1] = weight_xc_I_0[i + 1]; + kptr[2] = weight_xc_I_0[i + 2]; + kptr[3] = weight_xc_I_0[i + 3]; + kptr[4] = weight_xc_F_0[i]; + kptr[5] = weight_xc_F_0[i + 1]; + kptr[6] = weight_xc_F_0[i + 2]; + kptr[7] = weight_xc_F_0[i + 3]; + kptr[8 + 0] = weight_xc_I_1[i]; + kptr[8 + 1] = weight_xc_I_1[i + 1]; + kptr[8 + 2] = weight_xc_I_1[i + 2]; + kptr[8 + 3] = weight_xc_I_1[i + 3]; + kptr[8 + 4] = weight_xc_F_1[i]; + kptr[8 + 5] = weight_xc_F_1[i + 1]; + kptr[8 + 6] = weight_xc_F_1[i + 2]; + kptr[8 + 7] = weight_xc_F_1[i + 3]; + kptr[16 + 0] = weight_xc_O_0[i]; + kptr[16 + 1] = weight_xc_O_0[i + 1]; + kptr[16 + 2] = weight_xc_O_0[i + 2]; + kptr[16 + 3] = weight_xc_O_0[i + 3]; + kptr[16 + 4] = weight_xc_G_0[i]; + kptr[16 + 5] = weight_xc_G_0[i + 1]; + kptr[16 + 6] = weight_xc_G_0[i + 2]; + kptr[16 + 7] = weight_xc_G_0[i + 3]; + kptr[24 + 0] = weight_xc_O_1[i]; + kptr[24 + 1] = weight_xc_O_1[i + 1]; + kptr[24 + 2] = weight_xc_O_1[i + 2]; + kptr[24 + 3] = weight_xc_O_1[i + 3]; + kptr[24 + 4] = weight_xc_G_1[i]; + kptr[24 + 5] = weight_xc_G_1[i + 1]; + kptr[24 + 6] = weight_xc_G_1[i + 2]; + kptr[24 + 7] = weight_xc_G_1[i + 3]; + kptr += 32; + } +#endif // __AVXVNNI__ || __AVX512VNNI__ + for (; i + 1 < size; i += 2) + { + kptr[0] = weight_xc_I_0[i]; + kptr[1] = weight_xc_I_0[i + 1]; + kptr[2] = weight_xc_F_0[i]; + kptr[3] = weight_xc_F_0[i + 1]; + kptr[4] = weight_xc_O_0[i]; + kptr[5] = weight_xc_O_0[i + 1]; + kptr[6] = weight_xc_G_0[i]; + kptr[7] = weight_xc_G_0[i + 1]; + kptr[8 + 0] = weight_xc_I_1[i]; + kptr[8 + 1] = weight_xc_I_1[i + 1]; + kptr[8 + 2] = weight_xc_F_1[i]; + kptr[8 + 3] = weight_xc_F_1[i + 1]; + kptr[8 + 4] = weight_xc_O_1[i]; + kptr[8 + 5] = weight_xc_O_1[i + 1]; + kptr[8 + 6] = weight_xc_G_1[i]; + kptr[8 + 7] = weight_xc_G_1[i + 1]; + kptr += 16; + } + for (; i < size; i++) + { + 
kptr[0] = weight_xc_I_0[i]; + kptr[1] = weight_xc_F_0[i]; + kptr[2] = weight_xc_O_0[i]; + kptr[3] = weight_xc_G_0[i]; + kptr[4] = weight_xc_I_1[i]; + kptr[5] = weight_xc_F_1[i]; + kptr[6] = weight_xc_O_1[i]; + kptr[7] = weight_xc_G_1[i]; + kptr += 8; + } + + i = 0; +#if __AVXVNNI__ || __AVX512VNNI__ + _w_shift = _mm256_setzero_si256(); + _v127 = _mm256_set1_epi8(127); + _w0_shift = _mm256_setzero_si256(); + _w1_shift = _mm256_setzero_si256(); +#if defined(__x86_64__) || defined(_M_X64) + _w2_shift = _mm256_setzero_si256(); + _w3_shift = _mm256_setzero_si256(); + for (; i + 15 < num_output; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_hc_I_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_hc_I_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 2), _mm_loadu_si128((const __m128i*)(weight_hc_F_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 3), _mm_loadu_si128((const __m128i*)(weight_hc_F_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 4), _mm_loadu_si128((const __m128i*)(weight_hc_O_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 5), _mm_loadu_si128((const __m128i*)(weight_hc_O_1 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 6), _mm_loadu_si128((const __m128i*)(weight_hc_G_0 + i))); + _mm_storeu_si128((__m128i*)(kptr + 16 * 7), _mm_loadu_si128((const __m128i*)(weight_hc_G_1 + i))); + + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); + __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); + _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm256_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm256_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 128; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_w0_shift, _w1_shift); + __m256i _tmp1 = _mm256_hadd_epi32(_w2_shift, _w3_shift); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _w_shift = _mm256_add_epi32(_w_shift, _tmp0); + } + + _w0_shift = _mm256_setzero_si256(); + _w1_shift = _mm256_setzero_si256(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_hc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_hc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 32), _mm_loadl_epi64((const __m128i*)(weight_hc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 40), _mm_loadl_epi64((const __m128i*)(weight_hc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 48), _mm_loadl_epi64((const __m128i*)(weight_hc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 56), _mm_loadl_epi64((const __m128i*)(weight_hc_G_1 + i))); + + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + _w0_shift = _mm256_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm256_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 64; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_w0_shift, _w1_shift); + _w_shift = _mm256_add_epi32(_w_shift, _tmp0); + } + + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_I_0[i]; + kptr[1] = weight_hc_I_0[i + 1]; + 
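// Sketch of what the _w_shift accumulation interleaved with the packing above
// precomputes (the helper name is an assumption). _mm256_dpbusd_epi32(acc, _v127, _w)
// accumulates 127 * (w0 + w1 + w2 + w3) per lane, so after the lane reduction each
// gate row carries the constant 127 * sum(row weights). That constant is stored right
// behind the packed rows and subtracted again in the runtime kernel, which is what
// makes the +127 activation bias of the VNNI path exact.
#include <cstdint>

static int32_t weight_shift_for_row(const signed char* w, int n)
{
    int32_t sum = 0;
    for (int i = 0; i < n; i++)
        sum += w[i];
    return 127 * sum; // value stored per gate row after the packed int8 block
}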
kptr[2] = weight_hc_I_0[i + 2]; + kptr[3] = weight_hc_I_0[i + 3]; + kptr[4] = weight_hc_F_0[i]; + kptr[5] = weight_hc_F_0[i + 1]; + kptr[6] = weight_hc_F_0[i + 2]; + kptr[7] = weight_hc_F_0[i + 3]; + kptr[8 + 0] = weight_hc_O_0[i]; + kptr[8 + 1] = weight_hc_O_0[i + 1]; + kptr[8 + 2] = weight_hc_O_0[i + 2]; + kptr[8 + 3] = weight_hc_O_0[i + 3]; + kptr[8 + 4] = weight_hc_G_0[i]; + kptr[8 + 5] = weight_hc_G_0[i + 1]; + kptr[8 + 6] = weight_hc_G_0[i + 2]; + kptr[8 + 7] = weight_hc_G_0[i + 3]; + kptr[16 + 0] = weight_hc_I_1[i]; + kptr[16 + 1] = weight_hc_I_1[i + 1]; + kptr[16 + 2] = weight_hc_I_1[i + 2]; + kptr[16 + 3] = weight_hc_I_1[i + 3]; + kptr[16 + 4] = weight_hc_F_1[i]; + kptr[16 + 5] = weight_hc_F_1[i + 1]; + kptr[16 + 6] = weight_hc_F_1[i + 2]; + kptr[16 + 7] = weight_hc_F_1[i + 3]; + kptr[24 + 0] = weight_hc_O_1[i]; + kptr[24 + 1] = weight_hc_O_1[i + 1]; + kptr[24 + 2] = weight_hc_O_1[i + 2]; + kptr[24 + 3] = weight_hc_O_1[i + 3]; + kptr[24 + 4] = weight_hc_G_1[i]; + kptr[24 + 5] = weight_hc_G_1[i + 1]; + kptr[24 + 6] = weight_hc_G_1[i + 2]; + kptr[24 + 7] = weight_hc_G_1[i + 3]; + + __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); + _w_shift = _mm256_dpbusd_epi32(_w_shift, _v127, _w); + + kptr += 32; + } + + _mm256_storeu_si256((__m256i*)kptr, _w_shift); + kptr += 32; +#else +#if defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_I_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_hc_F_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_hc_F_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 32), _mm_loadl_epi64((const __m128i*)(weight_hc_O_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 40), _mm_loadl_epi64((const __m128i*)(weight_hc_O_1 + i))); + _mm_storel_epi64((__m128i*)(kptr + 48), _mm_loadl_epi64((const __m128i*)(weight_hc_G_0 + i))); + _mm_storel_epi64((__m128i*)(kptr + 56), _mm_loadl_epi64((const __m128i*)(weight_hc_G_1 + i))); + kptr += 64; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_I_0[i]; + kptr[1] = weight_hc_I_0[i + 1]; + kptr[2] = weight_hc_I_0[i + 2]; + kptr[3] = weight_hc_I_0[i + 3]; + kptr[4] = weight_hc_F_0[i]; + kptr[5] = weight_hc_F_0[i + 1]; + kptr[6] = weight_hc_F_0[i + 2]; + kptr[7] = weight_hc_F_0[i + 3]; + kptr[8 + 0] = weight_hc_I_1[i]; + kptr[8 + 1] = weight_hc_I_1[i + 1]; + kptr[8 + 2] = weight_hc_I_1[i + 2]; + kptr[8 + 3] = weight_hc_I_1[i + 3]; + kptr[8 + 4] = weight_hc_F_1[i]; + kptr[8 + 5] = weight_hc_F_1[i + 1]; + kptr[8 + 6] = weight_hc_F_1[i + 2]; + kptr[8 + 7] = weight_hc_F_1[i + 3]; + kptr[16 + 0] = weight_hc_O_0[i]; + kptr[16 + 1] = weight_hc_O_0[i + 1]; + kptr[16 + 2] = weight_hc_O_0[i + 2]; + kptr[16 + 3] = weight_hc_O_0[i + 3]; + kptr[16 + 4] = weight_hc_G_0[i]; + kptr[16 + 5] = weight_hc_G_0[i + 1]; + kptr[16 + 6] = weight_hc_G_0[i + 2]; + kptr[16 + 7] = weight_hc_G_0[i + 3]; + kptr[24 + 0] = weight_hc_O_1[i]; + kptr[24 + 1] = weight_hc_O_1[i + 1]; + kptr[24 + 2] = weight_hc_O_1[i + 2]; + kptr[24 + 3] = weight_hc_O_1[i + 3]; + kptr[24 + 4] = weight_hc_G_1[i]; + kptr[24 + 5] = weight_hc_G_1[i + 1]; + kptr[24 + 6] = weight_hc_G_1[i + 2]; + kptr[24 + 7] = weight_hc_G_1[i + 3]; + kptr += 32; + } +#endif // __AVXVNNI__ || __AVX512VNNI__ + for (; i + 1 < num_output; i += 2) + { + kptr[0] = 
weight_hc_I_0[i]; + kptr[1] = weight_hc_I_0[i + 1]; + kptr[2] = weight_hc_F_0[i]; + kptr[3] = weight_hc_F_0[i + 1]; + kptr[4] = weight_hc_O_0[i]; + kptr[5] = weight_hc_O_0[i + 1]; + kptr[6] = weight_hc_G_0[i]; + kptr[7] = weight_hc_G_0[i + 1]; + kptr[8 + 0] = weight_hc_I_1[i]; + kptr[8 + 1] = weight_hc_I_1[i + 1]; + kptr[8 + 2] = weight_hc_F_1[i]; + kptr[8 + 3] = weight_hc_F_1[i + 1]; + kptr[8 + 4] = weight_hc_O_1[i]; + kptr[8 + 5] = weight_hc_O_1[i + 1]; + kptr[8 + 6] = weight_hc_G_1[i]; + kptr[8 + 7] = weight_hc_G_1[i + 1]; + kptr += 16; + } + for (; i < num_output; i++) + { + kptr[0] = weight_hc_I_0[i]; + kptr[1] = weight_hc_F_0[i]; + kptr[2] = weight_hc_O_0[i]; + kptr[3] = weight_hc_G_0[i]; + kptr[4] = weight_hc_I_1[i]; + kptr[5] = weight_hc_F_1[i]; + kptr[6] = weight_hc_O_1[i]; + kptr[7] = weight_hc_G_1[i]; + kptr += 8; + } + + descales_ptr[0] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 0 + q]; + descales_ptr[1] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 1 + q]; + descales_ptr[2] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 2 + q]; + descales_ptr[3] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 3 + q]; + descales_ptr[4] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 0 + q + 1]; + descales_ptr[5] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 1 + q + 1]; + descales_ptr[6] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 2 + q + 1]; + descales_ptr[7] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 3 + q + 1]; + descales_ptr[8 + 0] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 0 + q]; + descales_ptr[8 + 1] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 1 + q]; + descales_ptr[8 + 2] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 2 + q]; + descales_ptr[8 + 3] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 3 + q]; + descales_ptr[8 + 4] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 0 + q + 1]; + descales_ptr[8 + 5] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 1 + q + 1]; + descales_ptr[8 + 6] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 2 + q + 1]; + descales_ptr[8 + 7] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 3 + q + 1]; + } +#endif // __AVX2__ + for (; q < hidden_size; q++) + { + bias_c_IFOG[0] = bias_c_I[q]; + bias_c_IFOG[1] = bias_c_F[q]; + bias_c_IFOG[2] = bias_c_O[q]; + bias_c_IFOG[3] = bias_c_G[q]; + + bias_c_IFOG += 4; + + const signed char* weight_xc_I = weight_xc_dr.row(hidden_size * 0 + q); + const signed char* weight_xc_F = weight_xc_dr.row(hidden_size * 1 + q); + const signed char* weight_xc_O = weight_xc_dr.row(hidden_size * 2 + q); + const signed char* weight_xc_G = weight_xc_dr.row(hidden_size * 3 + q); + + const signed char* weight_hc_I = weight_hc_dr.row(hidden_size * 0 + q); + const signed char* weight_hc_F = weight_hc_dr.row(hidden_size * 1 + q); + const signed char* weight_hc_O = weight_hc_dr.row(hidden_size * 2 + q); + const signed char* weight_hc_G = weight_hc_dr.row(hidden_size * 3 + q); + +#if __AVX512F__ + signed char* kptr = weight_data_tm_dr.row(q / 4 + (q % 4) / 2 + q % 2); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q / 4 + (q % 4) / 2 + q % 2); +#elif __AVX2__ + signed char* kptr = weight_data_tm_dr.row(q / 2 + q % 2); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q / 2 + q % 2); +#else + signed char* kptr = weight_data_tm_dr.row(q); + float* descales_ptr = weight_data_tm_int8_descales_dr.row(q); +#endif + + int i = 0; +#if __SSE2__ +#if __AVXVNNI__ || __AVX512VNNI__ + __m128i _w_shift = _mm_setzero_si128(); + __m128i _v127 = _mm_set1_epi8(127); + __m128i _w0_shift = _mm_setzero_si128(); + __m128i 
_w1_shift = _mm_setzero_si128(); +#if defined(__x86_64__) || defined(_M_X64) + __m128i _w2_shift = _mm_setzero_si128(); + __m128i _w3_shift = _mm_setzero_si128(); + for (; i + 15 < size; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_xc_I + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_xc_F + i))); + _mm_storeu_si128((__m128i*)(kptr + 32), _mm_loadu_si128((const __m128i*)(weight_xc_O + i))); + _mm_storeu_si128((__m128i*)(kptr + 48), _mm_loadu_si128((const __m128i*)(weight_xc_G + i))); + + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); + __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); + _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 64; + } + { + transpose4x4_epi32(_w0_shift, _w1_shift, _w2_shift, _w3_shift); + _w_shift = _mm_add_epi32(_w_shift, _w0_shift); + _w_shift = _mm_add_epi32(_w_shift, _w1_shift); + _w_shift = _mm_add_epi32(_w_shift, _w2_shift); + _w_shift = _mm_add_epi32(_w_shift, _w3_shift); + } + + _w0_shift = _mm_setzero_si128(); + _w1_shift = _mm_setzero_si128(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_F + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_xc_O + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_xc_G + i))); + + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 32; + } + { + __m128i _tmp0 = _mm_hadd_epi32(_w0_shift, _w1_shift); + _w_shift = _mm_add_epi32(_w_shift, _tmp0); + } + + for (; i + 3 < size; i += 4) + { + kptr[0] = weight_xc_I[i]; + kptr[1] = weight_xc_I[i + 1]; + kptr[2] = weight_xc_I[i + 2]; + kptr[3] = weight_xc_I[i + 3]; + kptr[4] = weight_xc_F[i]; + kptr[5] = weight_xc_F[i + 1]; + kptr[6] = weight_xc_F[i + 2]; + kptr[7] = weight_xc_F[i + 3]; + kptr[8 + 0] = weight_xc_O[i]; + kptr[8 + 1] = weight_xc_O[i + 1]; + kptr[8 + 2] = weight_xc_O[i + 2]; + kptr[8 + 3] = weight_xc_O[i + 3]; + kptr[8 + 4] = weight_xc_G[i]; + kptr[8 + 5] = weight_xc_G[i + 1]; + kptr[8 + 6] = weight_xc_G[i + 2]; + kptr[8 + 7] = weight_xc_G[i + 3]; + + __m128i _w = _mm_loadu_si128((const __m128i*)kptr); + _w_shift = _mm_dpbusd_epi32(_w_shift, _v127, _w); + + kptr += 16; + } + + _mm_storeu_si128((__m128i*)kptr, _w_shift); + kptr += 16; +#else +#if defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < size; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_xc_I + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_xc_F + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_xc_O + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_xc_G + i))); + kptr += 32; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < size; i += 4) + { + kptr[0] = 
weight_xc_I[i]; + kptr[1] = weight_xc_I[i + 1]; + kptr[2] = weight_xc_I[i + 2]; + kptr[3] = weight_xc_I[i + 3]; + kptr[4] = weight_xc_F[i]; + kptr[5] = weight_xc_F[i + 1]; + kptr[6] = weight_xc_F[i + 2]; + kptr[7] = weight_xc_F[i + 3]; + kptr[8 + 0] = weight_xc_O[i]; + kptr[8 + 1] = weight_xc_O[i + 1]; + kptr[8 + 2] = weight_xc_O[i + 2]; + kptr[8 + 3] = weight_xc_O[i + 3]; + kptr[8 + 4] = weight_xc_G[i]; + kptr[8 + 5] = weight_xc_G[i + 1]; + kptr[8 + 6] = weight_xc_G[i + 2]; + kptr[8 + 7] = weight_xc_G[i + 3]; + kptr += 16; + } +#endif // __AVXVNNI__ || __AVX512VNNI__ + for (; i + 1 < size; i += 2) + { + kptr[0] = weight_xc_I[i]; + kptr[1] = weight_xc_I[i + 1]; + kptr[2] = weight_xc_F[i]; + kptr[3] = weight_xc_F[i + 1]; + kptr[4] = weight_xc_O[i]; + kptr[5] = weight_xc_O[i + 1]; + kptr[6] = weight_xc_G[i]; + kptr[7] = weight_xc_G[i + 1]; + kptr += 8; + } +#endif // __SSE2__ + for (; i < size; i++) + { + kptr[0] = weight_xc_I[i]; + kptr[1] = weight_xc_F[i]; + kptr[2] = weight_xc_O[i]; + kptr[3] = weight_xc_G[i]; + kptr += 4; + } + + i = 0; +#if __SSE2__ +#if __AVXVNNI__ || __AVX512VNNI__ + _w_shift = _mm_setzero_si128(); + _w0_shift = _mm_setzero_si128(); + _w1_shift = _mm_setzero_si128(); +#if defined(__x86_64__) || defined(_M_X64) + _w2_shift = _mm_setzero_si128(); + _w3_shift = _mm_setzero_si128(); + for (; i + 15 < num_output; i += 16) + { + _mm_storeu_si128((__m128i*)kptr, _mm_loadu_si128((const __m128i*)(weight_hc_I + i))); + _mm_storeu_si128((__m128i*)(kptr + 16), _mm_loadu_si128((const __m128i*)(weight_hc_F + i))); + _mm_storeu_si128((__m128i*)(kptr + 32), _mm_loadu_si128((const __m128i*)(weight_hc_O + i))); + _mm_storeu_si128((__m128i*)(kptr + 48), _mm_loadu_si128((const __m128i*)(weight_hc_G + i))); + + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); + __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); + _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); + _w2_shift = _mm_dpbusd_epi32(_w2_shift, _v127, _w2); + _w3_shift = _mm_dpbusd_epi32(_w3_shift, _v127, _w3); + + kptr += 64; + } + { + transpose4x4_epi32(_w0_shift, _w1_shift, _w2_shift, _w3_shift); + _w_shift = _mm_add_epi32(_w_shift, _w0_shift); + _w_shift = _mm_add_epi32(_w_shift, _w1_shift); + _w_shift = _mm_add_epi32(_w_shift, _w2_shift); + _w_shift = _mm_add_epi32(_w_shift, _w3_shift); + } + + _w0_shift = _mm_setzero_si128(); + _w1_shift = _mm_setzero_si128(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_F + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_hc_O + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_hc_G + i))); + + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + _w0_shift = _mm_dpbusd_epi32(_w0_shift, _v127, _w0); + _w1_shift = _mm_dpbusd_epi32(_w1_shift, _v127, _w1); + + kptr += 32; + } + { + __m128i _tmp0 = _mm_hadd_epi32(_w0_shift, _w1_shift); + _w_shift = _mm_add_epi32(_w_shift, _tmp0); + } + + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_I[i + 1]; + kptr[2] = weight_hc_I[i + 2]; + kptr[3] = weight_hc_I[i + 3]; + 
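// Sketch of the packed-row indexing used by the leftover-unit paths above (the
// function names are assumptions). Hidden units are packed in groups of 4 first
// (AVX512 rows), then pairs (AVX2 rows), then singles, so a unit q that falls into
// a narrower group has to skip the rows consumed by the wider ones. For example with
// hidden_size = 7, units {0,1,2,3} share row 0, {4,5} share row 1, and unit 6 uses row 2.
static int lstm_packed_row_avx512(int q)
{
    return q / 4 + (q % 4) / 2 + q % 2;
}

// Without AVX512 the same scheme degenerates to pairs followed by singles.
static int lstm_packed_row_avx2(int q)
{
    return q / 2 + q % 2;
}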
kptr[4] = weight_hc_F[i]; + kptr[5] = weight_hc_F[i + 1]; + kptr[6] = weight_hc_F[i + 2]; + kptr[7] = weight_hc_F[i + 3]; + kptr[8 + 0] = weight_hc_O[i]; + kptr[8 + 1] = weight_hc_O[i + 1]; + kptr[8 + 2] = weight_hc_O[i + 2]; + kptr[8 + 3] = weight_hc_O[i + 3]; + kptr[8 + 4] = weight_hc_G[i]; + kptr[8 + 5] = weight_hc_G[i + 1]; + kptr[8 + 6] = weight_hc_G[i + 2]; + kptr[8 + 7] = weight_hc_G[i + 3]; + + __m128i _w = _mm_loadu_si128((const __m128i*)kptr); + _w_shift = _mm_dpbusd_epi32(_w_shift, _v127, _w); + + kptr += 16; + } + + _mm_storeu_si128((__m128i*)kptr, _w_shift); + kptr += 16; +#else +#if defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < num_output; i += 8) + { + _mm_storel_epi64((__m128i*)kptr, _mm_loadl_epi64((const __m128i*)(weight_hc_I + i))); + _mm_storel_epi64((__m128i*)(kptr + 8), _mm_loadl_epi64((const __m128i*)(weight_hc_F + i))); + _mm_storel_epi64((__m128i*)(kptr + 16), _mm_loadl_epi64((const __m128i*)(weight_hc_O + i))); + _mm_storel_epi64((__m128i*)(kptr + 24), _mm_loadl_epi64((const __m128i*)(weight_hc_G + i))); + kptr += 32; + } +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < num_output; i += 4) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_I[i + 1]; + kptr[2] = weight_hc_I[i + 2]; + kptr[3] = weight_hc_I[i + 3]; + kptr[4] = weight_hc_F[i]; + kptr[5] = weight_hc_F[i + 1]; + kptr[6] = weight_hc_F[i + 2]; + kptr[7] = weight_hc_F[i + 3]; + kptr[8 + 0] = weight_hc_O[i]; + kptr[8 + 1] = weight_hc_O[i + 1]; + kptr[8 + 2] = weight_hc_O[i + 2]; + kptr[8 + 3] = weight_hc_O[i + 3]; + kptr[8 + 4] = weight_hc_G[i]; + kptr[8 + 5] = weight_hc_G[i + 1]; + kptr[8 + 6] = weight_hc_G[i + 2]; + kptr[8 + 7] = weight_hc_G[i + 3]; + kptr += 16; + } +#endif // __AVXVNNI__ || __AVX512VNNI__ + for (; i + 1 < num_output; i += 2) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_I[i + 1]; + kptr[2] = weight_hc_F[i]; + kptr[3] = weight_hc_F[i + 1]; + kptr[4] = weight_hc_O[i]; + kptr[5] = weight_hc_O[i + 1]; + kptr[6] = weight_hc_G[i]; + kptr[7] = weight_hc_G[i + 1]; + kptr += 8; + } +#endif // __SSE2__ + for (; i < num_output; i++) + { + kptr[0] = weight_hc_I[i]; + kptr[1] = weight_hc_F[i]; + kptr[2] = weight_hc_O[i]; + kptr[3] = weight_hc_G[i]; + kptr += 4; + } + + descales_ptr[0] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 0 + q]; + descales_ptr[1] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 1 + q]; + descales_ptr[2] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 2 + q]; + descales_ptr[3] = 1.f / weight_xc_int8_scales_ptr[hidden_size * 3 + q]; + descales_ptr[4] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 0 + q]; + descales_ptr[5] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 1 + q]; + descales_ptr[6] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 2 + q]; + descales_ptr[7] = 1.f / weight_hc_int8_scales_ptr[hidden_size * 3 + q]; + } + } +} + +static void lstm_int8(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ +#if NCNN_RUNTIME_CPU && NCNN_AVX512VNNI && __AVX512F__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx512_vnni()) + { + lstm_int8_avx512vnni(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVXVNNI && __AVX2__ && !__AVX512F__ && !__AVXVNNI__ && 
!__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx_vnni()) + { + lstm_int8_avxvnni(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_AVX2 && __AVX__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_avx2()) + { + lstm_int8_avx2(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); + return; + } +#endif + +#if NCNN_RUNTIME_CPU && NCNN_XOP && __SSE2__ && !__XOP__ && !__AVX2__ && !__AVXVNNI__ && !__AVX512VNNI__ + if (ncnn::cpu_support_x86_xop()) + { + lstm_int8_xop(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); + return; + } +#endif + + int size = bottom_blob_int8.w; + int T = bottom_blob_int8.h; + + int num_output = top_blob.w; + int hidden_size = cell_state.w; + + // 4 x hidden_size + Mat gates(4, hidden_size, 4u, opt.workspace_allocator); + + Mat tmp_hidden_state; + if (num_output != hidden_size) + { + tmp_hidden_state.create(hidden_size, 4u, opt.workspace_allocator); + } + + Mat hidden_state_int8(num_output, (size_t)1u, 1, opt.workspace_allocator); + float hidden_state_int8_scale = 1.f; + float hidden_state_int8_descale = 1.f; + + // unroll + for (int t = 0; t < T; t++) + { + int ti = reverse ? T - 1 - t : t; + + // dynamic quantize hidden_state + { + float absmax = 0.f; + for (int i = 0; i < num_output; i++) + { + absmax = std::max(absmax, (float)fabs(hidden_state[i])); + } + + if (absmax == 0.f) + { + hidden_state_int8.fill(0); + } + else + { + hidden_state_int8_scale = 127.f / absmax; + hidden_state_int8_descale = absmax / 127.f; + + signed char* hs = hidden_state_int8; + for (int i = 0; i < num_output; i++) + { + hs[i] = float2int8(hidden_state[i] * hidden_state_int8_scale); + } + } + } + + int remain_hidden_size_start = 0; + int nn_hidden_size = 0; +#if __AVX2__ +#if __AVX512F__ + nn_hidden_size = hidden_size >> 2; + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = qq * 4; + + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = bottom_blob_int8_descales[ti]; + const float descale_h = hidden_state_int8_descale; + + const float* bias_c_IFOG = (const float*)bias_c + q * 4; + + const signed char* kptr = weight_data_tm.row(q / 4); + const float* descales_ptr = weight_data_tm_int8_descales.row(q / 4); + + float* gates_data = gates.row(q); + + __m512i _lstm_IFOGx0 = _mm512_setzero_si512(); + __m512i _sum0 = _mm512_setzero_si512(); + __m512i _sum1 = _mm512_setzero_si512(); + int i = 0; +#if __AVX512VNNI__ + __m128i _v127q = _mm_set1_epi8(127); + __m512i _v127 = _mm512_set1_epi8(127); + +#if defined(__x86_64__) || defined(_M_X64) + __m512i _sum2 = _mm512_setzero_si512(); + __m512i _sum3 = _mm512_setzero_si512(); + for (; i + 15 < size; i += 16) + { + __m128i _xi = _mm_loadu_si128((const __m128i*)(x + i)); + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + __m512i _w2 = _mm512_loadu_si512((const __m512i*)(kptr + 128)); + __m512i _w3 = _mm512_loadu_si512((const __m512i*)(kptr + 192)); + + _xi = _mm_add_epi8(_xi, _v127q); + __m512i _xii = _mm512_broadcast_i32x4(_xi); + + _sum0 = 
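// A minimal scalar sketch of the per-timestep dynamic quantization applied to
// hidden_state above (float2int8 in the kernel is mirrored here by a clamp-and-round
// stand-in; the helper name is an assumption). The scale is recomputed from the
// current absmax every step, and its reciprocal is kept as the descale used when the
// int32 accumulators are converted back to float.
#include <algorithm>
#include <cmath>
#include <vector>

static void dynamic_quantize(const std::vector<float>& v, std::vector<signed char>& q,
                             float& scale, float& descale)
{
    float absmax = 0.f;
    for (float x : v)
        absmax = std::max(absmax, std::fabs(x));

    q.assign(v.size(), 0);
    if (absmax == 0.f)
    {
        scale = 1.f;
        descale = 1.f;
        return; // an all-zero vector stays all-zero, as in the fill(0) branch above
    }

    scale = 127.f / absmax;
    descale = absmax / 127.f;
    for (size_t i = 0; i < v.size(); i++)
    {
        float t = std::min(std::max(v[i] * scale, -127.f), 127.f);
        q[i] = (signed char)std::lroundf(t);
    }
}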
_mm512_dpbusd_epi32(_sum0, _xii, _w0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _xii, _w1); + _sum2 = _mm512_dpbusd_epi32(_sum2, _xii, _w2); + _sum3 = _mm512_dpbusd_epi32(_sum3, _xii, _w3); + + kptr += 256; + } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum3); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum3); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _sum3 = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum0); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum1); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum2); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum3); + } + + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < size; i += 8) + { + __m128i _xi = _mm_loadl_epi64((const __m128i*)(x + i)); + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + + _xi = _mm_add_epi8(_xi, _v127q); + __m512i _xii = _mm512_broadcastq_epi64(_xi); + + _sum0 = _mm512_dpbusd_epi32(_sum0, _xii, _w0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _xii, _w1); + + kptr += 128; + } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _tmp0); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _tmp1); + } + + for (; i + 3 < size; i += 4) + { + __m512i _xi = _mm512_set1_epi32(((const int*)(x + i))[0]); + __m512i _w = _mm512_loadu_si512((const __m512i*)kptr); + + _xi = _mm512_add_epi8(_xi, _v127); + _lstm_IFOGx0 = _mm512_dpbusd_epi32(_lstm_IFOGx0, _xi, _w); + + kptr += 64; + } + { + __m512i _w_shift = _mm512_loadu_si512((const __m512i*)kptr); + _lstm_IFOGx0 = _mm512_sub_epi32(_lstm_IFOGx0, _w_shift); + kptr += 64; + } +#else +#if defined(__x86_64__) || defined(_M_X64) + __m512i _sum2 = _mm512_setzero_si512(); + __m512i _sum3 = _mm512_setzero_si512(); + for (; i + 7 < size; i += 8) + { + __m256i _xi = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(x + i))); + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); + __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); + + __m512i _xii = _mm512_cvtepi8_epi16(_xi); + __m512i _ww0 = _mm512_cvtepi8_epi16(_w0); + __m512i _ww1 = _mm512_cvtepi8_epi16(_w1); + __m512i _ww2 = _mm512_cvtepi8_epi16(_w2); + __m512i _ww3 = _mm512_cvtepi8_epi16(_w3); + + __m512i _s0 = _mm512_madd_epi16(_ww0, _xii); + __m512i _s1 = _mm512_madd_epi16(_ww1, _xii); + __m512i _s2 = _mm512_madd_epi16(_ww2, _xii); + __m512i _s3 = _mm512_madd_epi16(_ww3, _xii); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + + kptr += 128; + } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum3); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum3); + _sum0 = 
_mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _sum3 = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum0); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum1); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum2); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _sum3); + } + + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < size; i += 4) + { + __m256i _xi = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i))); + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + + __m512i _xii = _mm512_cvtepi8_epi16(_xi); + __m512i _ww0 = _mm512_cvtepi8_epi16(_w0); + __m512i _ww1 = _mm512_cvtepi8_epi16(_w1); + + __m512i _s0 = _mm512_madd_epi16(_ww0, _xii); + __m512i _s1 = _mm512_madd_epi16(_ww1, _xii); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + + kptr += 64; + } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _tmp0); + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _tmp1); + } +#endif // __AVX512VNNI__ + for (; i + 1 < size; i += 2) + { + __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _xi = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i))); + + __m512i _ww = _mm512_cvtepi8_epi16(_w); + __m512i _xixi = _mm512_cvtepi8_epi16(_xi); + + __m512i _xixi0 = _mm512_shuffle_epi32(_xixi, _MM_PERM_AAAA); + +#if __AVX512VNNI__ + _lstm_IFOGx0 = _mm512_dpwssd_epi32(_lstm_IFOGx0, _ww, _xixi0); +#else + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _mm512_madd_epi16(_ww, _xixi0)); +#endif // __AVX512VNNI__ + + kptr += 32; + } + for (; i < size; i++) + { + __m128i _w = _mm_load_si128((const __m128i*)kptr); + __m256i _xi = _mm256_set1_epi16(x[i]); + + __m256i _ww = _mm256_cvtepi8_epi16(_w); + + __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_ww, _xi)); + + _lstm_IFOGx0 = _mm512_add_epi32(_lstm_IFOGx0, _s0); + + kptr += 16; + } + + __m512i _lstm_IFOGh0 = _mm512_setzero_si512(); + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); + i = 0; +#if __AVX512VNNI__ +#if defined(__x86_64__) || defined(_M_X64) + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + for (; i + 15 < num_output; i += 16) + { + __m128i _h_cont = _mm_loadu_si128((const __m128i*)(hs + i)); + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + __m512i _w2 = _mm512_loadu_si512((const __m512i*)(kptr + 128)); + __m512i _w3 = _mm512_loadu_si512((const __m512i*)(kptr + 192)); + + _h_cont = _mm_add_epi8(_h_cont, _v127q); + __m512i _hh_cont = _mm512_broadcast_i32x4(_h_cont); + + _sum0 = _mm512_dpbusd_epi32(_sum0, _hh_cont, _w0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _hh_cont, _w1); + _sum2 = _mm512_dpbusd_epi32(_sum2, _hh_cont, _w2); + _sum3 = _mm512_dpbusd_epi32(_sum3, _hh_cont, _w3); + + kptr += 256; + } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum3); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, 
_sum3); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _sum3 = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum0); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum1); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum2); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum3); + } + + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < num_output; i += 8) + { + __m128i _h_cont = _mm_loadl_epi64((const __m128i*)(hs + i)); + __m512i _w0 = _mm512_loadu_si512((const __m512i*)kptr); + __m512i _w1 = _mm512_loadu_si512((const __m512i*)(kptr + 64)); + + _h_cont = _mm_add_epi8(_h_cont, _v127q); + __m512i _hh_cont = _mm512_broadcastq_epi64(_h_cont); + + _sum0 = _mm512_dpbusd_epi32(_sum0, _hh_cont, _w0); + _sum1 = _mm512_dpbusd_epi32(_sum1, _hh_cont, _w1); + + kptr += 128; + } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _tmp0); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _tmp1); + } + + for (; i + 3 < num_output; i += 4) + { + __m512i _h_cont = _mm512_set1_epi32(((const int*)(hs + i))[0]); + __m512i _w = _mm512_loadu_si512((const __m512i*)kptr); + + _h_cont = _mm512_add_epi8(_h_cont, _v127); + _lstm_IFOGh0 = _mm512_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w); + + kptr += 64; + } + { + __m512i _w_shift = _mm512_loadu_si512((const __m512i*)kptr); + _lstm_IFOGh0 = _mm512_sub_epi32(_lstm_IFOGh0, _w_shift); + kptr += 64; + } +#else +#if defined(__x86_64__) || defined(_M_X64) + _sum2 = _mm512_setzero_si512(); + _sum3 = _mm512_setzero_si512(); + for (; i + 7 < num_output; i += 8) + { + __m256i _h_cont = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(hs + i))); + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); + __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); + + __m512i _hh_cont = _mm512_cvtepi8_epi16(_h_cont); + __m512i _ww0 = _mm512_cvtepi8_epi16(_w0); + __m512i _ww1 = _mm512_cvtepi8_epi16(_w1); + __m512i _ww2 = _mm512_cvtepi8_epi16(_w2); + __m512i _ww3 = _mm512_cvtepi8_epi16(_w3); + + __m512i _s0 = _mm512_madd_epi16(_ww0, _hh_cont); + __m512i _s1 = _mm512_madd_epi16(_ww1, _hh_cont); + __m512i _s2 = _mm512_madd_epi16(_ww2, _hh_cont); + __m512i _s3 = _mm512_madd_epi16(_ww3, _hh_cont); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + _sum2 = _mm512_add_epi32(_sum2, _s2); + _sum3 = _mm512_add_epi32(_sum3, _s3); + + kptr += 128; + } + { + __m512i _tmp0 = _mm512_unpacklo_epi32(_sum0, _sum1); + __m512i _tmp1 = _mm512_unpackhi_epi32(_sum0, _sum1); + __m512i _tmp2 = _mm512_unpacklo_epi32(_sum2, _sum3); + __m512i _tmp3 = _mm512_unpackhi_epi32(_sum2, _sum3); + _sum0 = _mm512_unpacklo_epi64(_tmp0, _tmp2); + _sum1 = _mm512_unpackhi_epi64(_tmp0, _tmp2); + _sum2 = _mm512_unpacklo_epi64(_tmp1, _tmp3); + _sum3 = _mm512_unpackhi_epi64(_tmp1, _tmp3); + + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum0); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum1); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum2); + 
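// Scalar sketch of the signed/unsigned trick the VNNI accumulation above relies on
// (variable and helper names are assumptions). dpbusd multiplies an unsigned operand
// by a signed one, so the signed int8 activations (quantized into [-127, 127]) are
// biased by +127 before the dot product, and the precomputed w_shift = 127 * sum(w)
// is subtracted afterwards:
//   sum_i (x_i + 127) * w_i - 127 * sum_i w_i == sum_i x_i * w_i
#include <cstdint>

static int32_t dot_int8_via_biased_unsigned(const signed char* x, const signed char* w,
                                            int n, int32_t w_shift /* 127 * sum(w) */)
{
    int32_t acc = 0;
    for (int i = 0; i < n; i++)
        acc += (uint8_t)(x[i] + 127) * (int32_t)w[i]; // what dpbusd accumulates lane-wise
    return acc - w_shift;                             // undo the +127 bias
}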
_lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _sum3); + } + + _sum0 = _mm512_setzero_si512(); + _sum1 = _mm512_setzero_si512(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < num_output; i += 4) + { + __m256i _h_cont = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i))); + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + + __m512i _hh_cont = _mm512_cvtepi8_epi16(_h_cont); + __m512i _ww0 = _mm512_cvtepi8_epi16(_w0); + __m512i _ww1 = _mm512_cvtepi8_epi16(_w1); + + __m512i _s0 = _mm512_madd_epi16(_ww0, _hh_cont); + __m512i _s1 = _mm512_madd_epi16(_ww1, _hh_cont); + _sum0 = _mm512_add_epi32(_sum0, _s0); + _sum1 = _mm512_add_epi32(_sum1, _s1); + + kptr += 64; + } + { + __m512i _tmp0 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m512i _tmp1 = _mm512_castps_si512(_mm512_shuffle_ps(_mm512_castsi512_ps(_sum0), _mm512_castsi512_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _tmp0); + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _tmp1); + } +#endif // __AVX512VNNI__ + for (; i + 1 < num_output; i += 2) + { + __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _h_cont = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i))); + + __m512i _ww = _mm512_cvtepi8_epi16(_w); + __m512i _hh_cont = _mm512_cvtepi8_epi16(_h_cont); + + __m512i _hh_cont0 = _mm512_shuffle_epi32(_hh_cont, _MM_PERM_AAAA); + +#if __AVX512VNNI__ + _lstm_IFOGh0 = _mm512_dpwssd_epi32(_lstm_IFOGh0, _ww, _hh_cont0); +#else + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _mm512_madd_epi16(_ww, _hh_cont0)); +#endif // __AVX512VNNI__ + + kptr += 32; + } + for (; i < num_output; i++) + { + __m128i _w = _mm_load_si128((const __m128i*)kptr); + __m256i _h_cont = _mm256_set1_epi16(hs[i]); + + __m256i _ww = _mm256_cvtepi8_epi16(_w); + + __m512i _s0 = _mm512_cvtepi16_epi32(_mm256_mullo_epi16(_ww, _h_cont)); + + _lstm_IFOGh0 = _mm512_add_epi32(_lstm_IFOGh0, _s0); + + kptr += 16; + } + + __m512 _descale_x = _mm512_set1_ps(descale_x); + __m512 _descale_h = _mm512_set1_ps(descale_h); + + __m512 _lstm_IFOG0 = _mm512_loadu_ps(bias_c_IFOG); + + __m512 _descale_xc_IFOG = _mm512_loadu_ps(descales_ptr); + + _lstm_IFOG0 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(_lstm_IFOGx0), _mm512_mul_ps(_descale_x, _descale_xc_IFOG), _lstm_IFOG0); + + __m512 _descale_hc_IFOG = _mm512_loadu_ps(descales_ptr + 16); + + _lstm_IFOG0 = _mm512_fmadd_ps(_mm512_cvtepi32_ps(_lstm_IFOGh0), _mm512_mul_ps(_descale_h, _descale_hc_IFOG), _lstm_IFOG0); + + _mm512_storeu_ps(gates_data, _lstm_IFOG0); + } + remain_hidden_size_start += nn_hidden_size << 2; + nn_hidden_size = (hidden_size - remain_hidden_size_start) >> 1; +#else + nn_hidden_size = hidden_size >> 1; +#endif // __AVX512F__ + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = remain_hidden_size_start + qq * 2; + + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = bottom_blob_int8_descales[ti]; + const float descale_h = hidden_state_int8_descale; + + const float* bias_c_IFOG = (const float*)bias_c + q * 4; + +#if __AVX512F__ + const signed char* kptr = weight_data_tm.row(q / 4 + (q % 4) / 2); + const float* descales_ptr = weight_data_tm_int8_descales.row(q / 4 + (q % 4) / 2); +#else + const signed char* kptr = weight_data_tm.row(q / 2); + const float* 
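// Scalar sketch of the dequantization fused into the fmadd sequence above (parameter
// names are assumptions): the activation descale and the per-row weight descale are
// folded into a single multiplier per term, and the bias is the starting value of the
// accumulator, so each gate pre-activation costs two fused multiply-adds.
#include <cstdint>

static float lstm_gate_preact(int32_t acc_x, int32_t acc_h, float bias,
                              float descale_x, float descale_wxc,
                              float descale_h, float descale_whc)
{
    float gate = bias;
    gate += (float)acc_x * (descale_x * descale_wxc); // x @ weight_xc contribution
    gate += (float)acc_h * (descale_h * descale_whc); // hidden @ weight_hc contribution
    return gate;
}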
descales_ptr = weight_data_tm_int8_descales.row(q / 2); +#endif + + float* gates_data = gates.row(q); + + __m256i _lstm_IFOGx0 = _mm256_setzero_si256(); + __m256i _sum0 = _mm256_setzero_si256(); + __m256i _sum1 = _mm256_setzero_si256(); + int i = 0; +#if __AVXVNNI__ || __AVX512VNNI__ + __m128i _v127q = _mm_set1_epi8(127); + __m256i _v127 = _mm256_set1_epi8(127); +#if defined(__x86_64__) || defined(_M_X64) + __m256i _sum2 = _mm256_setzero_si256(); + __m256i _sum3 = _mm256_setzero_si256(); + for (; i + 15 < size; i += 16) + { + __m128i _xi = _mm_loadu_si128((const __m128i*)(x + i)); + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); + __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); + + _xi = _mm_add_epi8(_xi, _v127q); + __m256i _xii = _mm256_inserti128_si256(_mm256_castsi128_si256(_xi), _xi, 1); + + _sum0 = _mm256_dpbusd_epi32(_sum0, _xii, _w0); + _sum1 = _mm256_dpbusd_epi32(_sum1, _xii, _w1); + _sum2 = _mm256_dpbusd_epi32(_sum2, _xii, _w2); + _sum3 = _mm256_dpbusd_epi32(_sum3, _xii, _w3); + + kptr += 128; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_hadd_epi32(_sum2, _sum3); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _tmp0); + } + + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < size; i += 8) + { + __m256i _xi = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(x + i))); + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + + _xi = _mm256_add_epi8(_xi, _v127); + _sum0 = _mm256_dpbusd_epi32(_sum0, _xi, _w0); + _sum1 = _mm256_dpbusd_epi32(_sum1, _xi, _w1); + + kptr += 64; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _tmp0); + } + + for (; i + 3 < size; i += 4) + { + __m256i _xi = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(x + i))); + __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); + + _xi = _mm256_add_epi8(_xi, _v127); + _lstm_IFOGx0 = _mm256_dpbusd_epi32(_lstm_IFOGx0, _xi, _w); + + kptr += 32; + } + { + __m256i _w_shift = _mm256_loadu_si256((const __m256i*)kptr); + _lstm_IFOGx0 = _mm256_sub_epi32(_lstm_IFOGx0, _w_shift); + kptr += 32; + } +#else +#if defined(__x86_64__) || defined(_M_X64) + __m256i _sum2 = _mm256_setzero_si256(); + __m256i _sum3 = _mm256_setzero_si256(); + for (; i + 7 < size; i += 8) + { + __m128i _xi = _mm_castpd_si128(_mm_load1_pd((const double*)(x + i))); + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); + __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); + + __m256i _xii = _mm256_cvtepi8_epi16(_xi); + __m256i _ww0 = _mm256_cvtepi8_epi16(_w0); + __m256i _ww1 = _mm256_cvtepi8_epi16(_w1); + __m256i _ww2 = _mm256_cvtepi8_epi16(_w2); + __m256i _ww3 = _mm256_cvtepi8_epi16(_w3); + + __m256i _s0 = _mm256_madd_epi16(_ww0, _xii); + __m256i _s1 = _mm256_madd_epi16(_ww1, _xii); + __m256i _s2 = _mm256_madd_epi16(_ww2, _xii); + __m256i _s3 = _mm256_madd_epi16(_ww3, _xii); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); + + kptr += 64; + } + { + __m256i _tmp0 = 
_mm256_hadd_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_hadd_epi32(_sum2, _sum3); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _tmp0); + } + + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < size; i += 4) + { + __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i))); + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + + __m256i _xii = _mm256_cvtepi8_epi16(_xi); + __m256i _ww0 = _mm256_cvtepi8_epi16(_w0); + __m256i _ww1 = _mm256_cvtepi8_epi16(_w1); + + __m256i _s0 = _mm256_madd_epi16(_ww0, _xii); + __m256i _s1 = _mm256_madd_epi16(_ww1, _xii); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + + kptr += 32; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _tmp0); + } +#endif // __AVXVNNI__ || __AVX512VNNI__ + for (; i + 1 < size; i += 2) + { + __m128i _w = _mm_loadu_si128((const __m128i*)kptr); + __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i))); + + __m256i _ww = _mm256_cvtepi8_epi16(_w); + __m256i _xixi = _mm256_cvtepi8_epi16(_xi); + + __m256i _xixi0 = _mm256_shuffle_epi32(_xixi, _MM_SHUFFLE(0, 0, 0, 0)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _lstm_IFOGx0 = _mm256_dpwssd_epi32(_lstm_IFOGx0, _ww, _xixi0); +#else + _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _mm256_madd_epi16(_ww, _xixi0)); +#endif // __AVXVNNI__ || __AVX512VNNI__ + + kptr += 16; + } + for (; i < size; i++) + { + __m128i _w = _mm_loadl_epi64((const __m128i*)kptr); + __m128i _xi = _mm_set1_epi16(x[i]); + + _w = _mm_cvtepi8_epi16(_w); + + __m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_w, _xi)); + + _lstm_IFOGx0 = _mm256_add_epi32(_lstm_IFOGx0, _s0); + + kptr += 8; + } + + __m256i _lstm_IFOGh0 = _mm256_setzero_si256(); + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); + i = 0; +#if __AVXVNNI__ || __AVX512VNNI__ +#if defined(__x86_64__) || defined(_M_X64) + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); + for (; i + 15 < num_output; i += 16) + { + __m128i _h_cont = _mm_loadu_si128((const __m128i*)(hs + i)); + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + __m256i _w2 = _mm256_loadu_si256((const __m256i*)(kptr + 64)); + __m256i _w3 = _mm256_loadu_si256((const __m256i*)(kptr + 96)); + + _h_cont = _mm_add_epi8(_h_cont, _v127q); + __m256i _hh_cont = _mm256_broadcastsi128_si256(_h_cont); + + _sum0 = _mm256_dpbusd_epi32(_sum0, _hh_cont, _w0); + _sum1 = _mm256_dpbusd_epi32(_sum1, _hh_cont, _w1); + _sum2 = _mm256_dpbusd_epi32(_sum2, _hh_cont, _w2); + _sum3 = _mm256_dpbusd_epi32(_sum3, _hh_cont, _w3); + + kptr += 128; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_hadd_epi32(_sum2, _sum3); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _tmp0); + } + + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < num_output; i += 8) + { + __m256i _h_cont = _mm256_castpd_si256(_mm256_broadcast_sd((const double*)(hs + i))); + __m256i _w0 = _mm256_loadu_si256((const __m256i*)kptr); + __m256i _w1 = _mm256_loadu_si256((const __m256i*)(kptr + 32)); + + _h_cont = _mm256_add_epi8(_h_cont, _v127); + _sum0 = _mm256_dpbusd_epi32(_sum0, _h_cont, _w0); + _sum1 = 
_mm256_dpbusd_epi32(_sum1, _h_cont, _w1); + + kptr += 64; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _tmp0); + } + + for (; i + 3 < num_output; i += 4) + { + __m256i _h_cont = _mm256_castps_si256(_mm256_broadcast_ss((const float*)(hs + i))); + __m256i _w = _mm256_loadu_si256((const __m256i*)kptr); + + _h_cont = _mm256_add_epi8(_h_cont, _v127); + _lstm_IFOGh0 = _mm256_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w); + + kptr += 32; + } + { + __m256i _w_shift = _mm256_loadu_si256((const __m256i*)kptr); + _lstm_IFOGh0 = _mm256_sub_epi32(_lstm_IFOGh0, _w_shift); + kptr += 32; + } +#else +#if defined(__x86_64__) || defined(_M_X64) + _sum2 = _mm256_setzero_si256(); + _sum3 = _mm256_setzero_si256(); + for (; i + 7 < num_output; i += 8) + { + __m128i _h_cont = _mm_castpd_si128(_mm_load1_pd((const double*)(hs + i))); + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); + __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); + + __m256i _hh_cont = _mm256_cvtepi8_epi16(_h_cont); + __m256i _ww0 = _mm256_cvtepi8_epi16(_w0); + __m256i _ww1 = _mm256_cvtepi8_epi16(_w1); + __m256i _ww2 = _mm256_cvtepi8_epi16(_w2); + __m256i _ww3 = _mm256_cvtepi8_epi16(_w3); + + __m256i _s0 = _mm256_madd_epi16(_ww0, _hh_cont); + __m256i _s1 = _mm256_madd_epi16(_ww1, _hh_cont); + __m256i _s2 = _mm256_madd_epi16(_ww2, _hh_cont); + __m256i _s3 = _mm256_madd_epi16(_ww3, _hh_cont); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + _sum2 = _mm256_add_epi32(_sum2, _s2); + _sum3 = _mm256_add_epi32(_sum3, _s3); + + kptr += 64; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + __m256i _tmp1 = _mm256_hadd_epi32(_sum2, _sum3); + _tmp0 = _mm256_hadd_epi32(_tmp0, _tmp1); + _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _tmp0); + } + + _sum0 = _mm256_setzero_si256(); + _sum1 = _mm256_setzero_si256(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < num_output; i += 4) + { + __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i))); + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + + __m256i _hh_cont = _mm256_cvtepi8_epi16(_h_cont); + __m256i _ww0 = _mm256_cvtepi8_epi16(_w0); + __m256i _ww1 = _mm256_cvtepi8_epi16(_w1); + + __m256i _s0 = _mm256_madd_epi16(_ww0, _hh_cont); + __m256i _s1 = _mm256_madd_epi16(_ww1, _hh_cont); + _sum0 = _mm256_add_epi32(_sum0, _s0); + _sum1 = _mm256_add_epi32(_sum1, _s1); + + kptr += 32; + } + { + __m256i _tmp0 = _mm256_hadd_epi32(_sum0, _sum1); + _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _tmp0); + } +#endif // __AVXVNNI__ || __AVX512VNNI__ + for (; i + 1 < num_output; i += 2) + { + __m128i _w = _mm_loadu_si128((const __m128i*)kptr); + __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i))); + + __m256i _ww = _mm256_cvtepi8_epi16(_w); + __m256i _hh_cont = _mm256_cvtepi8_epi16(_h_cont); + + __m256i _hh_cont0 = _mm256_shuffle_epi32(_hh_cont, _MM_SHUFFLE(0, 0, 0, 0)); + +#if __AVXVNNI__ || __AVX512VNNI__ + _lstm_IFOGh0 = _mm256_dpwssd_epi32(_lstm_IFOGh0, _ww, _hh_cont0); +#else + _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _mm256_madd_epi16(_ww, _hh_cont0)); +#endif // __AVXVNNI__ || __AVX512VNNI__ + + kptr += 16; + } + for (; i < num_output; i++) + { + __m128i _w = _mm_loadl_epi64((const __m128i*)kptr); + __m128i _h_cont = _mm_set1_epi16(hs[i]); + + _w = 
_mm_cvtepi8_epi16(_w); + + __m256i _s0 = _mm256_cvtepi16_epi32(_mm_mullo_epi16(_w, _h_cont)); + + _lstm_IFOGh0 = _mm256_add_epi32(_lstm_IFOGh0, _s0); + + kptr += 8; + } + + __m256 _descale_x = _mm256_set1_ps(descale_x); + __m256 _descale_h = _mm256_set1_ps(descale_h); + + __m256 _lstm_IFOG0 = _mm256_loadu_ps(bias_c_IFOG); + + __m256 _descale_xc_IFOG = _mm256_loadu_ps(descales_ptr); + + _lstm_IFOG0 = _mm256_comp_fmadd_ps(_mm256_cvtepi32_ps(_lstm_IFOGx0), _mm256_mul_ps(_descale_x, _descale_xc_IFOG), _lstm_IFOG0); + + __m256 _descale_hc_IFOG = _mm256_loadu_ps(descales_ptr + 8); + + _lstm_IFOG0 = _mm256_comp_fmadd_ps(_mm256_cvtepi32_ps(_lstm_IFOGh0), _mm256_mul_ps(_descale_h, _descale_hc_IFOG), _lstm_IFOG0); + + _mm256_storeu_ps(gates_data, _lstm_IFOG0); + } + remain_hidden_size_start += nn_hidden_size << 1; +#endif // __AVX2__ + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const signed char* x = bottom_blob_int8.row(ti); + const signed char* hs = hidden_state_int8; + const float descale_x = bottom_blob_int8_descales[ti]; + const float descale_h = hidden_state_int8_descale; + + // gate reset update + const float* bias_c_IFOG = (const float*)bias_c + q * 4; + +#if __AVX512F__ + const signed char* kptr = weight_data_tm.row(q / 4 + (q % 4) / 2 + q % 2); + const float* descales_ptr = weight_data_tm_int8_descales.row(q / 4 + (q % 4) / 2 + q % 2); +#elif __AVX2__ + const signed char* kptr = weight_data_tm.row(q / 2 + q % 2); + const float* descales_ptr = weight_data_tm_int8_descales.row(q / 2 + q % 2); +#else + const signed char* kptr = weight_data_tm.row(q); + const float* descales_ptr = weight_data_tm_int8_descales.row(q); +#endif + + float* gates_data = gates.row(q); + +#if __SSE2__ + __m128i _lstm_IFOGx0 = _mm_setzero_si128(); + __m128i _sum0 = _mm_setzero_si128(); + __m128i _sum1 = _mm_setzero_si128(); + int i = 0; +#if __AVXVNNI__ || __AVX512VNNI__ + __m128i _v127 = _mm_set1_epi8(127); +#if defined(__x86_64__) || defined(_M_X64) + __m128i _sum2 = _mm_setzero_si128(); + __m128i _sum3 = _mm_setzero_si128(); + for (; i + 15 < size; i += 16) + { + __m128i _xi = _mm_loadu_si128((const __m128i*)(x + i)); + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); + __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); + + _xi = _mm_add_epi8(_xi, _v127); + _sum0 = _mm_dpbusd_epi32(_sum0, _xi, _w0); + _sum1 = _mm_dpbusd_epi32(_sum1, _xi, _w1); + _sum2 = _mm_dpbusd_epi32(_sum2, _xi, _w2); + _sum3 = _mm_dpbusd_epi32(_sum3, _xi, _w3); + + kptr += 64; + } + { + transpose4x4_epi32(_sum0, _sum1, _sum2, _sum3); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum0); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum1); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum2); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum3); + } + + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < size; i += 8) + { + __m128i _xi = _mm_castpd_si128(_mm_load1_pd((const double*)(x + i))); + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + + _xi = _mm_add_epi8(_xi, _v127); + _sum0 = _mm_dpbusd_epi32(_sum0, _xi, _w0); + _sum1 = _mm_dpbusd_epi32(_sum1, _xi, _w1); + + kptr += 32; + } + { + __m128i _tmp0 = _mm_hadd_epi32(_sum0, _sum1); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _tmp0); + } + + 
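+        // note: _mm_dpbusd_epi32 multiplies unsigned bytes by signed bytes, so the
+        // signed int8 activation is biased by +127 above to make it non-negative.
+        // The spurious 127 * sum(weights) term is cancelled below by subtracting the
+        // precomputed _w_shift row stored after the packed int8 weights.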
for (; i + 3 < size; i += 4) + { + __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i))); + __m128i _w = _mm_loadu_si128((const __m128i*)kptr); + + _xi = _mm_add_epi8(_xi, _v127); + _lstm_IFOGx0 = _mm_dpbusd_epi32(_lstm_IFOGx0, _xi, _w); + + kptr += 16; + } + { + __m128i _w_shift = _mm_loadu_si128((const __m128i*)kptr); + _lstm_IFOGx0 = _mm_sub_epi32(_lstm_IFOGx0, _w_shift); + kptr += 16; + } +#else +#if defined(__x86_64__) || defined(_M_X64) + __m128i _sum2 = _mm_setzero_si128(); + __m128i _sum3 = _mm_setzero_si128(); + for (; i + 7 < size; i += 8) + { + __m128i _xi = _mm_castpd_si128(_mm_load1_pd((const double*)(x + i))); + __m128i _w0 = _mm_loadl_epi64((const __m128i*)kptr); + __m128i _w1 = _mm_loadl_epi64((const __m128i*)(kptr + 8)); + __m128i _w2 = _mm_loadl_epi64((const __m128i*)(kptr + 16)); + __m128i _w3 = _mm_loadl_epi64((const __m128i*)(kptr + 24)); + +#if __SSE4_1__ + _xi = _mm_cvtepi8_epi16(_xi); + _w0 = _mm_cvtepi8_epi16(_w0); + _w1 = _mm_cvtepi8_epi16(_w1); + _w2 = _mm_cvtepi8_epi16(_w2); + _w3 = _mm_cvtepi8_epi16(_w3); +#else + _xi = _mm_unpacklo_epi8(_xi, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi)); + _w0 = _mm_unpacklo_epi8(_w0, _mm_cmpgt_epi8(_mm_setzero_si128(), _w0)); + _w1 = _mm_unpacklo_epi8(_w1, _mm_cmpgt_epi8(_mm_setzero_si128(), _w1)); + _w2 = _mm_unpacklo_epi8(_w2, _mm_cmpgt_epi8(_mm_setzero_si128(), _w2)); + _w3 = _mm_unpacklo_epi8(_w3, _mm_cmpgt_epi8(_mm_setzero_si128(), _w3)); +#endif + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_w0, _xi, _sum0); + _sum1 = _mm_maddd_epi16(_w1, _xi, _sum1); + _sum2 = _mm_maddd_epi16(_w2, _xi, _sum2); + _sum3 = _mm_maddd_epi16(_w3, _xi, _sum3); +#else + __m128i _s0 = _mm_madd_epi16(_w0, _xi); + __m128i _s1 = _mm_madd_epi16(_w1, _xi); + __m128i _s2 = _mm_madd_epi16(_w2, _xi); + __m128i _s3 = _mm_madd_epi16(_w3, _xi); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); +#endif + + kptr += 32; + } + { + transpose4x4_epi32(_sum0, _sum1, _sum2, _sum3); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum0); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum1); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum2); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _sum3); + } + + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < size; i += 4) + { + __m128i _xi = _mm_castps_si128(_mm_load1_ps((const float*)(x + i))); + __m128i _w0 = _mm_loadl_epi64((const __m128i*)kptr); + __m128i _w1 = _mm_loadl_epi64((const __m128i*)(kptr + 8)); + +#if __SSE4_1__ + _xi = _mm_cvtepi8_epi16(_xi); + _w0 = _mm_cvtepi8_epi16(_w0); + _w1 = _mm_cvtepi8_epi16(_w1); +#else + _xi = _mm_unpacklo_epi8(_xi, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi)); + _w0 = _mm_unpacklo_epi8(_w0, _mm_cmpgt_epi8(_mm_setzero_si128(), _w0)); + _w1 = _mm_unpacklo_epi8(_w1, _mm_cmpgt_epi8(_mm_setzero_si128(), _w1)); +#endif + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_w0, _xi, _sum0); + _sum1 = _mm_maddd_epi16(_w1, _xi, _sum1); +#else + __m128i _s0 = _mm_madd_epi16(_w0, _xi); + __m128i _s1 = _mm_madd_epi16(_w1, _xi); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); +#endif + + kptr += 16; + } + { +#if __SSSE3__ + __m128i _tmp0 = _mm_hadd_epi32(_sum0, _sum1); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _tmp0); +#else + __m128i _tmp0 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_sum0), _mm_castsi128_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m128i _tmp1 = 
_mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_sum0), _mm_castsi128_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _tmp0); + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _tmp1); +#endif // __SSSE3__ + } +#endif // __AVXVNNI__ || __AVX512VNNI__ + for (; i + 1 < size; i += 2) + { + __m128i _w = _mm_loadl_epi64((const __m128i*)kptr); + __m128i _xi = _mm_set1_epi16(((const short*)(x + i))[0]); + +#if __SSE4_1__ + _w = _mm_cvtepi8_epi16(_w); + _xi = _mm_cvtepi8_epi16(_xi); +#else + _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); + _xi = _mm_unpacklo_epi8(_xi, _mm_cmpgt_epi8(_mm_setzero_si128(), _xi)); +#endif + +#if __XOP__ + _lstm_IFOGx0 = _mm_maddd_epi16(_w, _xi, _lstm_IFOGx0); +#else + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _mm_madd_epi16(_w, _xi)); +#endif + + kptr += 8; + } + for (; i < size; i++) + { + __m128i _w = _mm_loadl_epi64((const __m128i*)kptr); + __m128i _xi = _mm_set1_epi16(x[i]); + +#if __SSE4_1__ + _w = _mm_cvtepi8_epi16(_w); +#else + _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); +#endif + +#if __XOP__ + _w = _mm_unpacklo_epi16(_w, _w); + + _lstm_IFOGx0 = _mm_maccd_epi16(_w, _xi, _lstm_IFOGx0); +#else + __m128i _sl = _mm_mullo_epi16(_w, _xi); + __m128i _sh = _mm_mulhi_epi16(_w, _xi); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + + _lstm_IFOGx0 = _mm_add_epi32(_lstm_IFOGx0, _s0); +#endif + + kptr += 4; + } + + __m128i _lstm_IFOGh0 = _mm_setzero_si128(); + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); + i = 0; +#if __AVXVNNI__ || __AVX512VNNI__ +#if defined(__x86_64__) || defined(_M_X64) + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); + for (; i + 15 < num_output; i += 16) + { + __m128i _h_cont = _mm_loadu_si128((const __m128i*)(hs + i)); + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + __m128i _w2 = _mm_loadu_si128((const __m128i*)(kptr + 32)); + __m128i _w3 = _mm_loadu_si128((const __m128i*)(kptr + 48)); + + _h_cont = _mm_add_epi8(_h_cont, _v127); + _sum0 = _mm_dpbusd_epi32(_sum0, _h_cont, _w0); + _sum1 = _mm_dpbusd_epi32(_sum1, _h_cont, _w1); + _sum2 = _mm_dpbusd_epi32(_sum2, _h_cont, _w2); + _sum3 = _mm_dpbusd_epi32(_sum3, _h_cont, _w3); + + kptr += 64; + } + { + transpose4x4_epi32(_sum0, _sum1, _sum2, _sum3); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum0); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum1); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum2); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum3); + } + + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 7 < num_output; i += 8) + { + __m128i _h_cont = _mm_castpd_si128(_mm_load1_pd((const double*)(hs + i))); + __m128i _w0 = _mm_loadu_si128((const __m128i*)kptr); + __m128i _w1 = _mm_loadu_si128((const __m128i*)(kptr + 16)); + + _h_cont = _mm_add_epi8(_h_cont, _v127); + _sum0 = _mm_dpbusd_epi32(_sum0, _h_cont, _w0); + _sum1 = _mm_dpbusd_epi32(_sum1, _h_cont, _w1); + + kptr += 32; + } + { + __m128i _tmp0 = _mm_hadd_epi32(_sum0, _sum1); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _tmp0); + } + + for (; i + 3 < num_output; i += 4) + { + __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i))); + __m128i _w = _mm_loadu_si128((const __m128i*)kptr); + + _h_cont = _mm_add_epi8(_h_cont, _v127); + _lstm_IFOGh0 = _mm_dpbusd_epi32(_lstm_IFOGh0, _h_cont, _w); + + kptr += 16; + } + { + __m128i _w_shift = _mm_loadu_si128((const __m128i*)kptr); + 
_lstm_IFOGh0 = _mm_sub_epi32(_lstm_IFOGh0, _w_shift); + kptr += 16; + } +#else +#if defined(__x86_64__) || defined(_M_X64) + _sum2 = _mm_setzero_si128(); + _sum3 = _mm_setzero_si128(); + for (; i + 7 < num_output; i += 8) + { + __m128i _h_cont = _mm_castpd_si128(_mm_load1_pd((const double*)(hs + i))); + __m128i _w0 = _mm_loadl_epi64((const __m128i*)kptr); + __m128i _w1 = _mm_loadl_epi64((const __m128i*)(kptr + 8)); + __m128i _w2 = _mm_loadl_epi64((const __m128i*)(kptr + 16)); + __m128i _w3 = _mm_loadl_epi64((const __m128i*)(kptr + 24)); + +#if __SSE4_1__ + _h_cont = _mm_cvtepi8_epi16(_h_cont); + _w0 = _mm_cvtepi8_epi16(_w0); + _w1 = _mm_cvtepi8_epi16(_w1); + _w2 = _mm_cvtepi8_epi16(_w2); + _w3 = _mm_cvtepi8_epi16(_w3); +#else + _h_cont = _mm_unpacklo_epi8(_h_cont, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont)); + _w0 = _mm_unpacklo_epi8(_w0, _mm_cmpgt_epi8(_mm_setzero_si128(), _w0)); + _w1 = _mm_unpacklo_epi8(_w1, _mm_cmpgt_epi8(_mm_setzero_si128(), _w1)); + _w2 = _mm_unpacklo_epi8(_w2, _mm_cmpgt_epi8(_mm_setzero_si128(), _w2)); + _w3 = _mm_unpacklo_epi8(_w3, _mm_cmpgt_epi8(_mm_setzero_si128(), _w3)); +#endif + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_w0, _h_cont, _sum0); + _sum1 = _mm_maddd_epi16(_w1, _h_cont, _sum1); + _sum2 = _mm_maddd_epi16(_w2, _h_cont, _sum2); + _sum3 = _mm_maddd_epi16(_w3, _h_cont, _sum3); +#else + __m128i _s0 = _mm_madd_epi16(_w0, _h_cont); + __m128i _s1 = _mm_madd_epi16(_w1, _h_cont); + __m128i _s2 = _mm_madd_epi16(_w2, _h_cont); + __m128i _s3 = _mm_madd_epi16(_w3, _h_cont); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); + _sum2 = _mm_add_epi32(_sum2, _s2); + _sum3 = _mm_add_epi32(_sum3, _s3); +#endif + + kptr += 32; + } + { + transpose4x4_epi32(_sum0, _sum1, _sum2, _sum3); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum0); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum1); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum2); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _sum3); + } + + _sum0 = _mm_setzero_si128(); + _sum1 = _mm_setzero_si128(); +#endif // defined(__x86_64__) || defined(_M_X64) + for (; i + 3 < num_output; i += 4) + { + __m128i _h_cont = _mm_castps_si128(_mm_load1_ps((const float*)(hs + i))); + __m128i _w0 = _mm_loadl_epi64((const __m128i*)kptr); + __m128i _w1 = _mm_loadl_epi64((const __m128i*)(kptr + 8)); + +#if __SSE4_1__ + _h_cont = _mm_cvtepi8_epi16(_h_cont); + _w0 = _mm_cvtepi8_epi16(_w0); + _w1 = _mm_cvtepi8_epi16(_w1); +#else + _h_cont = _mm_unpacklo_epi8(_h_cont, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont)); + _w0 = _mm_unpacklo_epi8(_w0, _mm_cmpgt_epi8(_mm_setzero_si128(), _w0)); + _w1 = _mm_unpacklo_epi8(_w1, _mm_cmpgt_epi8(_mm_setzero_si128(), _w1)); +#endif + +#if __XOP__ + _sum0 = _mm_maddd_epi16(_w0, _h_cont, _sum0); + _sum1 = _mm_maddd_epi16(_w1, _h_cont, _sum1); +#else + __m128i _s0 = _mm_madd_epi16(_w0, _h_cont); + __m128i _s1 = _mm_madd_epi16(_w1, _h_cont); + _sum0 = _mm_add_epi32(_sum0, _s0); + _sum1 = _mm_add_epi32(_sum1, _s1); +#endif + + kptr += 16; + } + { +#if __SSSE3__ + __m128i _tmp0 = _mm_hadd_epi32(_sum0, _sum1); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _tmp0); +#else + __m128i _tmp0 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_sum0), _mm_castsi128_ps(_sum1), _MM_SHUFFLE(2, 0, 2, 0))); + __m128i _tmp1 = _mm_castps_si128(_mm_shuffle_ps(_mm_castsi128_ps(_sum0), _mm_castsi128_ps(_sum1), _MM_SHUFFLE(3, 1, 3, 1))); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _tmp0); + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _tmp1); +#endif // __SSSE3__ + } +#endif // __AVXVNNI__ || __AVX512VNNI__ + 
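+        // leftover hidden elements are handled in pairs and singles below, widening
+        // int8 to int16 and accumulating with madd/mullo. The x and h accumulators
+        // stay separate so each gets its own descale when converted back to float:
+        //   gate = bias + accX * (descale_x * descale_xc) + accH * (descale_h * descale_hc)
+        // e.g. accX = 1000 with descale_x = 0.02 and descale_xc = 0.02 contributes
+        // 1000 * 0.02 * 0.02 = 0.4 to the gate before the bias is added.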
for (; i + 1 < num_output; i += 2) + { + __m128i _w = _mm_loadl_epi64((const __m128i*)kptr); + __m128i _h_cont = _mm_set1_epi16(((const short*)(hs + i))[0]); + +#if __SSE4_1__ + _w = _mm_cvtepi8_epi16(_w); + _h_cont = _mm_cvtepi8_epi16(_h_cont); +#else + _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); + _h_cont = _mm_unpacklo_epi8(_h_cont, _mm_cmpgt_epi8(_mm_setzero_si128(), _h_cont)); +#endif + +#if __XOP__ + _lstm_IFOGh0 = _mm_maddd_epi16(_w, _h_cont, _lstm_IFOGh0); +#else + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _mm_madd_epi16(_w, _h_cont)); +#endif + + kptr += 8; + } + for (; i < num_output; i++) + { + __m128i _w = _mm_loadl_epi64((const __m128i*)kptr); + __m128i _h_cont = _mm_set1_epi16(hs[i]); + +#if __SSE4_1__ + _w = _mm_cvtepi8_epi16(_w); +#else + _w = _mm_unpacklo_epi8(_w, _mm_cmpgt_epi8(_mm_setzero_si128(), _w)); +#endif + +#if __XOP__ + _w = _mm_unpacklo_epi16(_w, _w); + + _lstm_IFOGh0 = _mm_maccd_epi16(_w, _h_cont, _lstm_IFOGh0); +#else + __m128i _sl = _mm_mullo_epi16(_w, _h_cont); + __m128i _sh = _mm_mulhi_epi16(_w, _h_cont); + __m128i _s0 = _mm_unpacklo_epi16(_sl, _sh); + + _lstm_IFOGh0 = _mm_add_epi32(_lstm_IFOGh0, _s0); +#endif + + kptr += 4; + } + + __m128 _descale_x = _mm_set1_ps(descale_x); + __m128 _descale_h = _mm_set1_ps(descale_h); + + __m128 _lstm_IFOG0 = _mm_loadu_ps(bias_c_IFOG); + + __m128 _descale_xc_IFOG = _mm_loadu_ps(descales_ptr); + + _lstm_IFOG0 = _mm_comp_fmadd_ps(_mm_cvtepi32_ps(_lstm_IFOGx0), _mm_mul_ps(_descale_x, _descale_xc_IFOG), _lstm_IFOG0); + + __m128 _descale_hc_IFOG = _mm_loadu_ps(descales_ptr + 4); + + _lstm_IFOG0 = _mm_comp_fmadd_ps(_mm_cvtepi32_ps(_lstm_IFOGh0), _mm_mul_ps(_descale_h, _descale_hc_IFOG), _lstm_IFOG0); + + _mm_storeu_ps(gates_data, _lstm_IFOG0); +#else + int Ix = 0; + int Fx = 0; + int Ox = 0; + int Gx = 0; + for (int i = 0; i < size; i++) + { + signed char xi = x[i]; + + Ix += kptr[0] * xi; + Fx += kptr[1] * xi; + Ox += kptr[2] * xi; + Gx += kptr[3] * xi; + + kptr += 4; + } + + int Ih = 0; + int Fh = 0; + int Oh = 0; + int Gh = 0; + for (int i = 0; i < num_output; i++) + { + signed char h_cont = hs[i]; + + Ih += kptr[0] * h_cont; + Fh += kptr[1] * h_cont; + Oh += kptr[2] * h_cont; + Gh += kptr[3] * h_cont; + + kptr += 4; + } + + const float descale_xc_I = descales_ptr[0]; + const float descale_xc_F = descales_ptr[1]; + const float descale_xc_O = descales_ptr[2]; + const float descale_xc_G = descales_ptr[3]; + const float descale_hc_I = descales_ptr[4]; + const float descale_hc_F = descales_ptr[5]; + const float descale_hc_O = descales_ptr[6]; + const float descale_hc_G = descales_ptr[7]; + + float I = bias_c_IFOG[0] + Ix * (descale_x * descale_xc_I) + Ih * (descale_h * descale_hc_I); + float F = bias_c_IFOG[1] + Fx * (descale_x * descale_xc_F) + Fh * (descale_h * descale_hc_F); + float O = bias_c_IFOG[2] + Ox * (descale_x * descale_xc_O) + Oh * (descale_h * descale_hc_O); + float G = bias_c_IFOG[3] + Gx * (descale_x * descale_xc_G) + Gh * (descale_h * descale_hc_G); + + gates_data[0] = I; + gates_data[1] = F; + gates_data[2] = O; + gates_data[3] = G; +#endif // __SSE2__ + } + + // lstm unit + // sigmoid(I) + // sigmoid(F) + // sigmoid(O) + // tanh(G) + // c_t := f_t .* c_{t-1} + i_t .* g_t + // h_t := o_t .* tanh[c_t] + float* output_data = top_blob.row(ti); + + float* cell_ptr = cell_state; + float* hidden_ptr = hidden_state; + float* tmp_hidden_ptr = tmp_hidden_state; + + remain_hidden_size_start = 0; +#if __SSE2__ +#if __AVX__ +#if __AVX512F__ + nn_hidden_size = hidden_size >> 4; + #pragma omp 
parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = qq * 16; + + const float* gates_data = gates.row(q); + + __m512 _IFOG_0 = _mm512_loadu_ps(gates_data); + __m512 _IFOG_1 = _mm512_loadu_ps(gates_data + 16); + __m512 _IFOG_2 = _mm512_loadu_ps(gates_data + 32); + __m512 _IFOG_3 = _mm512_loadu_ps(gates_data + 48); + + __m512 _tmp0 = _mm512_shuffle_f32x4(_IFOG_0, _IFOG_1, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp1 = _mm512_shuffle_f32x4(_IFOG_2, _IFOG_3, _MM_SHUFFLE(1, 0, 1, 0)); + __m512 _tmp2 = _mm512_shuffle_f32x4(_IFOG_0, _IFOG_1, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _tmp3 = _mm512_shuffle_f32x4(_IFOG_2, _IFOG_3, _MM_SHUFFLE(3, 2, 3, 2)); + __m512 _lstm_I = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _lstm_F = _mm512_shuffle_f32x4(_tmp0, _tmp1, _MM_SHUFFLE(3, 1, 3, 1)); + __m512 _lstm_O = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(2, 0, 2, 0)); + __m512 _lstm_G = _mm512_shuffle_f32x4(_tmp2, _tmp3, _MM_SHUFFLE(3, 1, 3, 1)); + + _lstm_I = sigmoid_avx512(_lstm_I); + _lstm_F = sigmoid_avx512(_lstm_F); + _lstm_O = sigmoid_avx512(_lstm_O); + _lstm_G = tanh_avx512(_lstm_G); + + __m512 _cell2 = _mm512_add_ps(_mm512_mul_ps(_lstm_F, _mm512_loadu_ps(cell_ptr + q)), _mm512_mul_ps(_lstm_I, _lstm_G)); + __m512 _lstm_H = _mm512_mul_ps(_lstm_O, tanh_avx512(_cell2)); + + _mm512_storeu_ps(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + _mm512_storeu_ps(hidden_ptr + q, _lstm_H); + _mm512_storeu_ps(output_data + q, _lstm_H); + } + else + { + _mm512_storeu_ps(tmp_hidden_ptr + q, _lstm_H); + } + } + remain_hidden_size_start += nn_hidden_size << 4; + nn_hidden_size = (hidden_size - remain_hidden_size_start) >> 3; +#else + nn_hidden_size = hidden_size >> 3; +#endif // __AVX512F__ + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = remain_hidden_size_start + qq * 8; + + const float* gates_data = gates.row(q); + + __m256 _IFOG_0 = _mm256_loadu_ps(gates_data); + __m256 _IFOG_1 = _mm256_loadu_ps(gates_data + 8); + __m256 _IFOG_2 = _mm256_loadu_ps(gates_data + 16); + __m256 _IFOG_3 = _mm256_loadu_ps(gates_data + 24); + +#if __AVX512F__ + __m256 _lstm_I = _mm256_permute2f128_ps(_IFOG_0, _IFOG_2, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _lstm_F = _mm256_permute2f128_ps(_IFOG_0, _IFOG_2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _lstm_O = _mm256_permute2f128_ps(_IFOG_1, _IFOG_3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _lstm_G = _mm256_permute2f128_ps(_IFOG_1, _IFOG_3, _MM_SHUFFLE(0, 3, 0, 1)); +#else + // unzip4 + __m256 _tmp0 = _mm256_permute2f128_ps(_IFOG_0, _IFOG_2, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp1 = _mm256_permute2f128_ps(_IFOG_1, _IFOG_3, _MM_SHUFFLE(0, 2, 0, 0)); + __m256 _tmp2 = _mm256_permute2f128_ps(_IFOG_0, _IFOG_2, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp3 = _mm256_permute2f128_ps(_IFOG_1, _IFOG_3, _MM_SHUFFLE(0, 3, 0, 1)); + __m256 _tmp4 = _mm256_unpacklo_ps(_tmp0, _tmp1); + __m256 _tmp5 = _mm256_unpacklo_ps(_tmp2, _tmp3); + __m256 _tmp6 = _mm256_unpackhi_ps(_tmp0, _tmp1); + __m256 _tmp7 = _mm256_unpackhi_ps(_tmp2, _tmp3); + __m256 _lstm_I = _mm256_unpacklo_ps(_tmp4, _tmp5); + __m256 _lstm_F = _mm256_unpackhi_ps(_tmp4, _tmp5); + __m256 _lstm_O = _mm256_unpacklo_ps(_tmp6, _tmp7); + __m256 _lstm_G = _mm256_unpackhi_ps(_tmp6, _tmp7); +#endif + + _lstm_I = sigmoid_avx(_lstm_I); + _lstm_F = sigmoid_avx(_lstm_F); + _lstm_O = sigmoid_avx(_lstm_O); + _lstm_G = tanh_avx(_lstm_G); + + __m256 _cell2 = _mm256_add_ps(_mm256_mul_ps(_lstm_F, _mm256_loadu_ps(cell_ptr + q)), 
_mm256_mul_ps(_lstm_I, _lstm_G)); + __m256 _lstm_H = _mm256_mul_ps(_lstm_O, tanh_avx(_cell2)); + + _mm256_storeu_ps(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + _mm256_storeu_ps(hidden_ptr + q, _lstm_H); + _mm256_storeu_ps(output_data + q, _lstm_H); + } + else + { + _mm256_storeu_ps(tmp_hidden_ptr + q, _lstm_H); + } + } + remain_hidden_size_start += nn_hidden_size << 3; + nn_hidden_size = (hidden_size - remain_hidden_size_start) >> 2; +#else + nn_hidden_size = hidden_size >> 2; +#endif // __AVX__ + #pragma omp parallel for num_threads(opt.num_threads) + for (int qq = 0; qq < nn_hidden_size; qq++) + { + int q = remain_hidden_size_start + qq * 4; + + const float* gates_data = gates.row(q); + + __m128 _lstm_I = _mm_loadu_ps(gates_data); + __m128 _lstm_F = _mm_loadu_ps(gates_data + 4); + __m128 _lstm_O = _mm_loadu_ps(gates_data + 8); + __m128 _lstm_G = _mm_loadu_ps(gates_data + 12); + +#if !__AVX512F__ + _MM_TRANSPOSE4_PS(_lstm_I, _lstm_F, _lstm_O, _lstm_G); +#endif + + _lstm_I = sigmoid_sse(_lstm_I); + _lstm_F = sigmoid_sse(_lstm_F); + _lstm_O = sigmoid_sse(_lstm_O); + _lstm_G = tanh_sse(_lstm_G); + + __m128 _cell2 = _mm_add_ps(_mm_mul_ps(_lstm_F, _mm_loadu_ps(cell_ptr + q)), _mm_mul_ps(_lstm_I, _lstm_G)); + __m128 _lstm_H = _mm_mul_ps(_lstm_O, tanh_sse(_cell2)); + + _mm_storeu_ps(cell_ptr + q, _cell2); + + if (num_output == hidden_size) + { + _mm_storeu_ps(hidden_ptr + q, _lstm_H); + _mm_storeu_ps(output_data + q, _lstm_H); + } + else + { + _mm_storeu_ps(tmp_hidden_ptr + q, _lstm_H); + } + } + remain_hidden_size_start += nn_hidden_size << 2; +#endif // __SSE2__ + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_hidden_size_start; q < hidden_size; q++) + { + const float* gates_data = gates.row(q); + + float I = gates_data[0]; + float F = gates_data[1]; + float O = gates_data[2]; + float G = gates_data[3]; + + I = 1.f / (1.f + expf(-I)); + F = 1.f / (1.f + expf(-F)); + O = 1.f / (1.f + expf(-O)); + G = tanhf(G); + + float cell2 = F * cell_ptr[q] + I * G; + float H = O * tanhf(cell2); + + cell_ptr[q] = cell2; + if (num_output == hidden_size) + { + hidden_ptr[q] = H; + output_data[q] = H; + } + else + { + tmp_hidden_ptr[q] = H; + } + } + + if (num_output != hidden_size) + { + // int nn_num_output = num_output >> 2; + // int remain_num_output_start = nn_num_output << 2; + // #pragma omp parallel for num_threads(opt.num_threads) + // for (int qq = 0; qq < nn_num_output; qq++) + // { + // int q = qq * 4; + // + // } + int remain_num_output_start = 0; + #pragma omp parallel for num_threads(opt.num_threads) + for (int q = remain_num_output_start; q < num_output; q++) + { + const float* hr = weight_hr.row(q); + const float* tmp_hidden_ptr = tmp_hidden_state; + + float H = 0; + for (int i = 0; i < hidden_size; i++) + { + H += tmp_hidden_ptr[i] * hr[i]; + } + + hidden_ptr[q] = H; + output_data[q] = H; + } + } + } +} diff --git a/src/layer/x86/lstm_x86.cpp b/src/layer/x86/lstm_x86.cpp index 6ba218e53d3..227f8b96a6a 100644 --- a/src/layer/x86/lstm_x86.cpp +++ b/src/layer/x86/lstm_x86.cpp @@ -24,10 +24,12 @@ #include "x86_activation.h" #include "x86_usability.h" -#include "layer_type.h" +#include "cpu.h" namespace ncnn { +#include "lstm_int8.h" + LSTM_x86::LSTM_x86() { one_blob_only = false; @@ -36,6 +38,13 @@ LSTM_x86::LSTM_x86() int LSTM_x86::create_pipeline(const Option& opt) { +#if NCNN_INT8 + if (int8_scale_term) + { + return create_pipeline_int8(opt); + } +#endif + // pack IFOG int num_directions = direction == 2 ? 
2 : 1; int size = weight_data_size / num_directions / hidden_size / 4; @@ -560,6 +569,13 @@ static int lstm(const Mat& bottom_blob, Mat& top_blob, int reverse, const Mat& w int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blob, top_blob, opt); + } +#endif + int T = bottom_blob.h; int num_directions = direction == 2 ? 2 : 1; @@ -597,16 +613,20 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) if (top_blob_reverse.empty()) return -100; - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); - if (ret0 != 0) - return ret0; + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + if (ret != 0) + return ret; + } hidden.fill(0.0f); cell.fill(0.0f); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); - if (ret1 != 0) - return ret1; + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + if (ret != 0) + return ret; + } // concat w for (int i = 0; i < T; i++) @@ -625,6 +645,13 @@ int LSTM_x86::forward(const Mat& bottom_blob, Mat& top_blob, const Option& opt) int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const { +#if NCNN_INT8 + if (int8_scale_term) + { + return forward_int8(bottom_blobs, top_blobs, opt); + } +#endif + const Mat& bottom_blob = bottom_blobs[0]; int T = bottom_blob.h; int num_directions = direction == 2 ? 2 : 1; @@ -675,15 +702,233 @@ int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& to Mat hidden0 = hidden.row_range(0, 1); Mat cell0 = cell.row_range(0, 1); - int ret0 = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); - if (ret0 != 0) - return ret0; + { + int ret = lstm(bottom_blob, top_blob_forward, 0, weight_xc_data_packed.channel(0), bias_c_data_packed.channel(0), weight_hc_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + if (ret != 0) + return ret; + } + + Mat hidden1 = hidden.row_range(1, 1); + Mat cell1 = cell.row_range(1, 1); + { + int ret = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + if (ret != 0) + return ret; + } + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + if (top_blobs.size() == 3) + { + top_blobs[1] = hidden; + top_blobs[2] = cell; + } + + return 0; +} + +#if NCNN_INT8 +int LSTM_x86::create_pipeline_int8(const Option& opt) +{ + // pack IFOG + const int num_directions = direction == 2 ? 2 : 1; + const int size = weight_data_size / num_directions / hidden_size / 4; + + lstm_transform_weight_int8(weight_xc_data, weight_xc_data_int8_scales, weight_hc_data, weight_hc_data_int8_scales, bias_c_data, weight_data_tm, weight_data_tm_int8_descales, bias_c_data_packed, size, num_output, num_directions, hidden_size, opt); + + if (opt.lightmode) + { + weight_xc_data.release(); + bias_c_data.release(); + weight_hc_data.release(); + weight_xc_data_int8_scales.release(); + weight_hc_data_int8_scales.release(); + } + + return 0; +} + +void LSTM_x86::dynamic_quantize(const Mat& bottom_blob, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const +{ + int size = bottom_blob.w; + int T = bottom_blob.h; + + // dynamic quantize bottom_blob + bottom_blob_int8_descales.create(T, (size_t)4u, 1, opt.blob_allocator); + + Mat bottom_blob_int8_scales(T, (size_t)4u, 1, opt.blob_allocator); + + // fp32 + for (int t = 0; t < T; t++) + { + const float* x = bottom_blob.row(t); + + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(x[i])); + } + + bottom_blob_int8_scales[t] = 127.f / absmax; + bottom_blob_int8_descales[t] = absmax / 127.f; + } + + quantize_to_int8(bottom_blob, bottom_blob_int8, bottom_blob_int8_scales, opt); +} + +int LSTM_x86::forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const +{ + int T = bottom_blob.h; + + int num_directions = direction == 2 ? 2 : 1; + + // initial hidden state + Mat hidden(num_output, 4u, opt.workspace_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + + Mat cell(hidden_size, 4u, opt.workspace_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // dynamic quantize bottom_blob + Mat bottom_blob_int8; + Mat bottom_blob_int8_descales; + { + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + dynamic_quantize(bottom_blob, bottom_blob_int8, bottom_blob_int8_descales, opt_quant); + } + + // Uni directional + if (direction == 0 || direction == 1) + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, direction, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(0), hidden, cell, opt); + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_forward, 0, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + } + + hidden.fill(0.f); + cell.fill(0.0f); + + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_reverse, 1, weight_data_tm.channel(1), weight_data_tm_int8_descales.channel(1), bias_c_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden, cell, opt); + } + + // concat w + for (int i = 0; i < T; i++) + { + const float* pf = top_blob_forward.row(i); + const float* pr = top_blob_reverse.row(i); + float* ptr = top_blob.row(i); + + memcpy(ptr, pf, num_output * sizeof(float)); + memcpy(ptr + num_output, pr, num_output * sizeof(float)); + } + } + + return 0; +} + +int LSTM_x86::forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const +{ + const Mat& bottom_blob = bottom_blobs[0]; + + int T = bottom_blob.h; + int num_directions = direction == 2 ? 2 : 1; + + Mat hidden; + Mat cell; + Allocator* hidden_cell_allocator = top_blobs.size() == 3 ? opt.blob_allocator : opt.workspace_allocator; + if (bottom_blobs.size() == 3) + { + hidden = bottom_blobs[1].clone(hidden_cell_allocator); + cell = bottom_blobs[2].clone(hidden_cell_allocator); + } + else + { + hidden.create(num_output, num_directions, 4u, hidden_cell_allocator); + if (hidden.empty()) + return -100; + hidden.fill(0.f); + + cell.create(hidden_size, num_directions, 4u, hidden_cell_allocator); + if (cell.empty()) + return -100; + cell.fill(0.f); + } + + Mat& top_blob = top_blobs[0]; + top_blob.create(num_output * num_directions, T, 4u, opt.blob_allocator); + if (top_blob.empty()) + return -100; + + // dynamic quantize bottom_blob + Mat bottom_blob_int8; + Mat bottom_blob_int8_descales; + { + Option opt_quant = opt; + opt_quant.blob_allocator = opt.workspace_allocator; + opt_quant.use_packing_layout = false; + dynamic_quantize(bottom_blob, bottom_blob_int8, bottom_blob_int8_descales, opt_quant); + } + + // Uni directional + if (direction == 0 || direction == 1) + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, direction, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? Mat() : weight_hr_data.channel(0), hidden, cell, opt); + } + + if (direction == 2) + { + Mat top_blob_forward(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_forward.empty()) + return -100; + + Mat top_blob_reverse(num_output, T, 4u, opt.workspace_allocator); + if (top_blob_reverse.empty()) + return -100; + + Mat hidden0 = hidden.row_range(0, 1); + Mat cell0 = cell.row_range(0, 1); + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_forward, 0, weight_data_tm.channel(0), weight_data_tm_int8_descales.channel(0), bias_c_data_packed.channel(0), num_output == hidden_size ? 
Mat() : weight_hr_data.channel(0), hidden0, cell0, opt); + } Mat hidden1 = hidden.row_range(1, 1); Mat cell1 = cell.row_range(1, 1); - int ret1 = lstm(bottom_blob, top_blob_reverse, 1, weight_xc_data_packed.channel(1), bias_c_data_packed.channel(1), weight_hc_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); - if (ret1 != 0) - return ret1; + { + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob_reverse, 1, weight_data_tm.channel(1), weight_data_tm_int8_descales.channel(1), bias_c_data_packed.channel(1), num_output == hidden_size ? Mat() : weight_hr_data.channel(1), hidden1, cell1, opt); + } // concat w for (int i = 0; i < T; i++) @@ -705,5 +950,6 @@ int LSTM_x86::forward(const std::vector& bottom_blobs, std::vector& to return 0; } +#endif // NCNN_INT8 } // namespace ncnn diff --git a/src/layer/x86/lstm_x86.h b/src/layer/x86/lstm_x86.h index 1dc56d45e03..d31b7377ccf 100644 --- a/src/layer/x86/lstm_x86.h +++ b/src/layer/x86/lstm_x86.h @@ -30,10 +30,24 @@ class LSTM_x86 : public LSTM virtual int forward(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +protected: +#if NCNN_INT8 + int create_pipeline_int8(const Option& opt); + void dynamic_quantize(const Mat& bottom_blob, Mat& bottom_blob_int8, Mat& bottom_blob_int8_descales, const Option& opt) const; + int forward_int8(const Mat& bottom_blob, Mat& top_blob, const Option& opt) const; + int forward_int8(const std::vector& bottom_blobs, std::vector& top_blobs, const Option& opt) const; +#endif + public: Mat weight_xc_data_packed; Mat bias_c_data_packed; Mat weight_hc_data_packed; + + Mat weight_data_tm; + +#if NCNN_INT8 + Mat weight_data_tm_int8_descales; +#endif }; } // namespace ncnn diff --git a/src/layer/x86/lstm_x86_avx2.cpp b/src/layer/x86/lstm_x86_avx2.cpp new file mode 100644 index 00000000000..2029b46b0c0 --- /dev/null +++ b/src/layer/x86/lstm_x86_avx2.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
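+
+// This translation unit recompiles the shared int8 LSTM kernels from lstm_int8.h
+// with AVX2 flags and exposes them through the *_avx2 wrappers below, so that
+// lstm_x86.cpp can select the best available variant via runtime CPU dispatch.
+// The other new per-ISA stubs (avxvnni, avx512vnni, xop) follow the same pattern.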
+ +#include "cpu.h" +#include "mat.h" +#include "layer.h" +#include "x86_activation.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "lstm_int8.h" + +void lstm_transform_weight_int8_avx2(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt) +{ + lstm_transform_weight_int8(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, hidden_size, opt); +} + +void lstm_int8_avx2(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/lstm_x86_avx512vnni.cpp b/src/layer/x86/lstm_x86_avx512vnni.cpp new file mode 100644 index 00000000000..656dcbb07de --- /dev/null +++ b/src/layer/x86/lstm_x86_avx512vnni.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "cpu.h" +#include "mat.h" +#include "layer.h" +#include "x86_activation.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "lstm_int8.h" + +void lstm_transform_weight_int8_avx512vnni(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt) +{ + lstm_transform_weight_int8(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, hidden_size, opt); +} + +void lstm_int8_avx512vnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/lstm_x86_avxvnni.cpp b/src/layer/x86/lstm_x86_avxvnni.cpp new file mode 100644 index 00000000000..925d781e877 --- /dev/null +++ b/src/layer/x86/lstm_x86_avxvnni.cpp @@ -0,0 +1,35 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "cpu.h" +#include "mat.h" +#include "layer.h" +#include "x86_activation.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "lstm_int8.h" + +void lstm_transform_weight_int8_avxvnni(const Mat& weight_xc, const Mat& weight_xc_int8_scales, const Mat& weight_hc, const Mat& weight_hc_int8_scales, const Mat& bias_c, Mat& weight_data_tm, Mat& weight_data_tm_int8_descales, Mat& bias_c_tm, int size, int num_output, int num_directions, int hidden_size, const Option& opt) +{ + lstm_transform_weight_int8(weight_xc, weight_xc_int8_scales, weight_hc, weight_hc_int8_scales, bias_c, weight_data_tm, weight_data_tm_int8_descales, bias_c_tm, size, num_output, num_directions, hidden_size, opt); +} + +void lstm_int8_avxvnni(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/lstm_x86_xop.cpp b/src/layer/x86/lstm_x86_xop.cpp new file mode 100644 index 00000000000..5345b3dd433 --- /dev/null +++ b/src/layer/x86/lstm_x86_xop.cpp @@ -0,0 +1,30 @@ +// Tencent is pleased to support the open source community by making ncnn available. +// +// Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved. +// +// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except +// in compliance with the License. You may obtain a copy of the License at +// +// https://opensource.org/licenses/BSD-3-Clause +// +// Unless required by applicable law or agreed to in writing, software distributed +// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR +// CONDITIONS OF ANY KIND, either express or implied. See the License for the +// specific language governing permissions and limitations under the License. 
+ +#include "cpu.h" +#include "mat.h" +#include "layer.h" +#include "x86_activation.h" +#include "x86_usability.h" + +namespace ncnn { + +#include "lstm_int8.h" + +void lstm_int8_xop(const Mat& bottom_blob_int8, const Mat& bottom_blob_int8_descales, Mat& top_blob, int reverse, const Mat& weight_data_tm, const Mat& weight_data_tm_int8_descales, const Mat& bias_c, const Mat& weight_hr, Mat& hidden_state, Mat& cell_state, const Option& opt) +{ + lstm_int8(bottom_blob_int8, bottom_blob_int8_descales, top_blob, reverse, weight_data_tm, weight_data_tm_int8_descales, bias_c, weight_hr, hidden_state, cell_state, opt); +} + +} // namespace ncnn diff --git a/src/layer/x86/x86_usability.h b/src/layer/x86/x86_usability.h index 583f4b10e7e..c838eb0c723 100644 --- a/src/layer/x86/x86_usability.h +++ b/src/layer/x86/x86_usability.h @@ -18,6 +18,8 @@ #include #if __SSE2__ #include +#if __SSSE3__ +#include #if __SSE4_1__ #include #if __AVX__ @@ -31,6 +33,7 @@ #endif #endif #endif +#endif #endif // __SSE2__ static NCNN_FORCEINLINE signed char float2int8(float v) diff --git a/tests/test_gru.cpp b/tests/test_gru.cpp index 487daeb3a27..124ad5434ab 100644 --- a/tests/test_gru.cpp +++ b/tests/test_gru.cpp @@ -32,13 +32,13 @@ static int test_gru(const ncnn::Mat& a, int outch, int direction) int ret = test_layer("GRU", pd, weights, a); if (ret != 0) { - fprintf(stderr, "test_gru failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_gru failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; } -int test_gru_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) +static int test_gru_with_hidden(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -63,13 +63,13 @@ int test_gru_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) int ret = test_layer("GRU", pd, weights, as, 2); if (ret != 0) { - fprintf(stderr, "test_gru_layer_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_gru_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; } -int test_gru_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction) +static int test_gru_with_hidden_input(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -94,13 +94,13 @@ int test_gru_layer_with_hidden_input(const ncnn::Mat& a, int outch, int directio int ret = test_layer("GRU", pd, weights, as, 1); if (ret != 0) { - fprintf(stderr, "test_gru_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_gru_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; } -int test_gru_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction) +static int test_gru_with_hidden_output(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; int num_directions = direction == 2 ? 
2 : 1; @@ -121,7 +121,7 @@ int test_gru_layer_with_hidden_output(const ncnn::Mat& a, int outch, int directi int ret = test_layer("GRU", pd, weights, as, 2); if (ret != 0) { - fprintf(stderr, "test_gru_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_gru_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; @@ -138,86 +138,87 @@ static int test_gru_0() || test_gru(RandomMat(5, 16), 16, 2) || test_gru(RandomMat(3, 16), 8, 2) || test_gru(RandomMat(8, 16), 16, 2) + || test_gru(RandomMat(31, 3), 31, 2) || test_gru(RandomMat(2, 5), 17, 2); } static int test_gru_1() { return 0 - || test_gru_layer_with_hidden(RandomMat(4, 4), 1, 2) - || test_gru_layer_with_hidden(RandomMat(8, 2), 2, 2) - || test_gru_layer_with_hidden(RandomMat(16, 8), 7, 2) - || test_gru_layer_with_hidden(RandomMat(17, 8), 8, 2) - || test_gru_layer_with_hidden(RandomMat(19, 15), 8, 2) - || test_gru_layer_with_hidden(RandomMat(5, 16), 16, 2) - || test_gru_layer_with_hidden(RandomMat(3, 16), 8, 2) - || test_gru_layer_with_hidden(RandomMat(2, 5), 99, 2) - || test_gru_layer_with_hidden(RandomMat(4, 4), 1, 1) - || test_gru_layer_with_hidden(RandomMat(8, 2), 2, 1) - || test_gru_layer_with_hidden(RandomMat(16, 8), 7, 1) - || test_gru_layer_with_hidden(RandomMat(17, 8), 8, 1) - || test_gru_layer_with_hidden(RandomMat(19, 15), 8, 1) - || test_gru_layer_with_hidden(RandomMat(5, 16), 16, 1) - || test_gru_layer_with_hidden(RandomMat(3, 16), 8, 1) - || test_gru_layer_with_hidden(RandomMat(2, 5), 99, 1) - || test_gru_layer_with_hidden(RandomMat(4, 2), 1, 0) - || test_gru_layer_with_hidden(RandomMat(8, 2), 2, 0) - || test_gru_layer_with_hidden(RandomMat(16, 8), 7, 0) - || test_gru_layer_with_hidden(RandomMat(17, 8), 8, 0) - || test_gru_layer_with_hidden(RandomMat(19, 15), 8, 0) - || test_gru_layer_with_hidden(RandomMat(5, 16), 16, 0) - || test_gru_layer_with_hidden(RandomMat(3, 16), 8, 0) - || test_gru_layer_with_hidden(RandomMat(2, 5), 17, 0) - - || test_gru_layer_with_hidden_input(RandomMat(4, 4), 1, 2) - || test_gru_layer_with_hidden_input(RandomMat(8, 2), 2, 2) - || test_gru_layer_with_hidden_input(RandomMat(16, 8), 7, 2) - || test_gru_layer_with_hidden_input(RandomMat(17, 8), 8, 2) - || test_gru_layer_with_hidden_input(RandomMat(19, 15), 8, 2) - || test_gru_layer_with_hidden_input(RandomMat(5, 16), 16, 2) - || test_gru_layer_with_hidden_input(RandomMat(3, 16), 8, 2) - || test_gru_layer_with_hidden_input(RandomMat(2, 5), 99, 2) - || test_gru_layer_with_hidden_input(RandomMat(4, 4), 1, 1) - || test_gru_layer_with_hidden_input(RandomMat(8, 2), 2, 1) - || test_gru_layer_with_hidden_input(RandomMat(16, 8), 7, 1) - || test_gru_layer_with_hidden_input(RandomMat(17, 8), 8, 1) - || test_gru_layer_with_hidden_input(RandomMat(19, 15), 8, 1) - || test_gru_layer_with_hidden_input(RandomMat(5, 16), 16, 1) - || test_gru_layer_with_hidden_input(RandomMat(3, 16), 8, 1) - || test_gru_layer_with_hidden_input(RandomMat(2, 5), 99, 1) - || test_gru_layer_with_hidden_input(RandomMat(4, 2), 1, 0) - || test_gru_layer_with_hidden_input(RandomMat(8, 2), 2, 0) - || test_gru_layer_with_hidden_input(RandomMat(16, 8), 7, 0) - || test_gru_layer_with_hidden_input(RandomMat(17, 8), 8, 0) - || test_gru_layer_with_hidden_input(RandomMat(19, 15), 8, 0) - || test_gru_layer_with_hidden_input(RandomMat(5, 16), 16, 0) - || test_gru_layer_with_hidden_input(RandomMat(3, 16), 8, 0) - || 
test_gru_layer_with_hidden_input(RandomMat(2, 5), 17, 0) - - || test_gru_layer_with_hidden_output(RandomMat(4, 4), 1, 2) - || test_gru_layer_with_hidden_output(RandomMat(8, 2), 2, 2) - || test_gru_layer_with_hidden_output(RandomMat(16, 8), 7, 2) - || test_gru_layer_with_hidden_output(RandomMat(17, 8), 8, 2) - || test_gru_layer_with_hidden_output(RandomMat(19, 15), 8, 2) - || test_gru_layer_with_hidden_output(RandomMat(5, 16), 16, 2) - || test_gru_layer_with_hidden_output(RandomMat(3, 16), 8, 2) - || test_gru_layer_with_hidden_output(RandomMat(2, 5), 99, 2) - || test_gru_layer_with_hidden_output(RandomMat(4, 4), 1, 1) - || test_gru_layer_with_hidden_output(RandomMat(8, 2), 2, 1) - || test_gru_layer_with_hidden_output(RandomMat(16, 8), 7, 1) - || test_gru_layer_with_hidden_output(RandomMat(17, 8), 8, 1) - || test_gru_layer_with_hidden_output(RandomMat(19, 15), 8, 1) - || test_gru_layer_with_hidden_output(RandomMat(5, 16), 16, 1) - || test_gru_layer_with_hidden_output(RandomMat(3, 16), 8, 1) - || test_gru_layer_with_hidden_output(RandomMat(2, 5), 99, 1) - || test_gru_layer_with_hidden_output(RandomMat(4, 2), 1, 0) - || test_gru_layer_with_hidden_output(RandomMat(8, 2), 2, 0) - || test_gru_layer_with_hidden_output(RandomMat(16, 8), 7, 0) - || test_gru_layer_with_hidden_output(RandomMat(17, 8), 8, 0) - || test_gru_layer_with_hidden_output(RandomMat(19, 15), 8, 0) - || test_gru_layer_with_hidden_output(RandomMat(5, 16), 16, 0) - || test_gru_layer_with_hidden_output(RandomMat(3, 16), 8, 0) - || test_gru_layer_with_hidden_output(RandomMat(2, 5), 17, 0); + || test_gru_with_hidden(RandomMat(4, 4), 1, 2) + || test_gru_with_hidden(RandomMat(8, 2), 2, 2) + || test_gru_with_hidden(RandomMat(16, 8), 7, 2) + || test_gru_with_hidden(RandomMat(17, 8), 8, 2) + || test_gru_with_hidden(RandomMat(19, 15), 8, 2) + || test_gru_with_hidden(RandomMat(5, 16), 16, 2) + || test_gru_with_hidden(RandomMat(3, 16), 8, 2) + || test_gru_with_hidden(RandomMat(2, 5), 79, 2) + || test_gru_with_hidden(RandomMat(4, 4), 1, 1) + || test_gru_with_hidden(RandomMat(8, 2), 2, 1) + || test_gru_with_hidden(RandomMat(16, 8), 7, 1) + || test_gru_with_hidden(RandomMat(17, 8), 8, 1) + || test_gru_with_hidden(RandomMat(19, 15), 8, 1) + || test_gru_with_hidden(RandomMat(5, 16), 16, 1) + || test_gru_with_hidden(RandomMat(3, 16), 8, 1) + || test_gru_with_hidden(RandomMat(2, 5), 79, 1) + || test_gru_with_hidden(RandomMat(4, 2), 1, 0) + || test_gru_with_hidden(RandomMat(8, 2), 2, 0) + || test_gru_with_hidden(RandomMat(16, 8), 7, 0) + || test_gru_with_hidden(RandomMat(17, 8), 8, 0) + || test_gru_with_hidden(RandomMat(19, 15), 8, 0) + || test_gru_with_hidden(RandomMat(5, 16), 16, 0) + || test_gru_with_hidden(RandomMat(3, 16), 8, 0) + || test_gru_with_hidden(RandomMat(2, 5), 17, 0) + + || test_gru_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_gru_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_gru_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_gru_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_gru_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_gru_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_gru_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_gru_with_hidden_input(RandomMat(2, 5), 79, 2) + || test_gru_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_gru_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_gru_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_gru_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_gru_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_gru_with_hidden_input(RandomMat(5, 16), 16, 
1) + || test_gru_with_hidden_input(RandomMat(3, 16), 8, 1) + || test_gru_with_hidden_input(RandomMat(2, 5), 79, 1) + || test_gru_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_gru_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_gru_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_gru_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_gru_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_gru_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_gru_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_gru_with_hidden_input(RandomMat(2, 5), 17, 0) + + || test_gru_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_gru_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_gru_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_gru_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_gru_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_gru_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_gru_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_gru_with_hidden_output(RandomMat(2, 5), 79, 2) + || test_gru_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_gru_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_gru_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_gru_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_gru_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_gru_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_gru_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_gru_with_hidden_output(RandomMat(2, 5), 79, 1) + || test_gru_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_gru_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_gru_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_gru_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_gru_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_gru_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_gru_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_gru_with_hidden_output(RandomMat(2, 5), 17, 0); } static int test_gru_2() @@ -248,8 +249,274 @@ static int test_gru_3() || test_gru(RandomMat(2, 5), 17, 1); } +#if NCNN_INT8 +static int test_gru_int8(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * 3 * num_directions); + pd.set(2, direction); + pd.set(8, 2); // int8_scale_term + + std::vector weights(5); + weights[0] = RandomS8Mat(outch * input_size * 3 * num_directions); + weights[1] = RandomMat(outch * 4 * num_directions); + weights[2] = RandomS8Mat(outch * outch * 3 * num_directions); + weights[3] = RandomMat(outch * 3 * num_directions, 100.f, 200.f); + weights[4] = RandomMat(outch * 3 * num_directions, 100.f, 200.f); + + int ret = test_layer("GRU", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_gru_int8 failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +static int test_gru_int8_with_hidden(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 
2 : 1;
+
+ ncnn::ParamDict pd;
+ pd.set(0, outch);
+ pd.set(1, outch * input_size * 3 * num_directions);
+ pd.set(2, direction);
+ pd.set(8, 2); // int8_scale_term
+
+ std::vector<ncnn::Mat> weights(5);
+ weights[0] = RandomS8Mat(outch * input_size * 3 * num_directions);
+ weights[1] = RandomMat(outch * 4 * num_directions);
+ weights[2] = RandomS8Mat(outch * outch * 3 * num_directions);
+ weights[3] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
+ weights[4] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
+
+ // initial hidden state
+ ncnn::Mat hidden = RandomMat(outch, num_directions);
+
+ std::vector<ncnn::Mat> as(2);
+ as[0] = a;
+ as[1] = hidden;
+
+ int ret = test_layer("GRU", pd, weights, as, 2);
+ if (ret != 0)
+ {
+ fprintf(stderr, "test_gru_int8_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction);
+ }
+
+ return ret;
+}
+
+static int test_gru_int8_with_hidden_input(const ncnn::Mat& a, int outch, int direction)
+{
+ int input_size = a.w;
+ int num_directions = direction == 2 ? 2 : 1;
+
+ ncnn::ParamDict pd;
+ pd.set(0, outch);
+ pd.set(1, outch * input_size * 3 * num_directions);
+ pd.set(2, direction);
+ pd.set(8, 2); // int8_scale_term
+
+ std::vector<ncnn::Mat> weights(5);
+ weights[0] = RandomS8Mat(outch * input_size * 3 * num_directions);
+ weights[1] = RandomMat(outch * 4 * num_directions);
+ weights[2] = RandomS8Mat(outch * outch * 3 * num_directions);
+ weights[3] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
+ weights[4] = RandomMat(outch * 3 * num_directions, 100.f, 200.f);
+
+ // initial hidden state
+ ncnn::Mat hidden = RandomMat(outch, num_directions);
+
+ std::vector<ncnn::Mat> as(2);
+ as[0] = a;
+ as[1] = hidden;
+
+ int ret = test_layer("GRU", pd, weights, as, 1);
+ if (ret != 0)
+ {
+ fprintf(stderr, "test_gru_int8_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction);
+ }
+
+ return ret;
+}
+
+static int test_gru_int8_with_hidden_output(const ncnn::Mat& a, int outch, int direction)
+{
+ int input_size = a.w;
+ int num_directions = direction == 2 ?
2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * 3 * num_directions); + pd.set(2, direction); + pd.set(8, 2); // int8_scale_term + + std::vector weights(5); + weights[0] = RandomS8Mat(outch * input_size * 3 * num_directions); + weights[1] = RandomMat(outch * 4 * num_directions); + weights[2] = RandomS8Mat(outch * outch * 3 * num_directions); + weights[3] = RandomMat(outch * 3 * num_directions, 100.f, 200.f); + weights[4] = RandomMat(outch * 3 * num_directions, 100.f, 200.f); + + std::vector as(1); + as[0] = a; + + int ret = test_layer("GRU", pd, weights, as, 2); + if (ret != 0) + { + fprintf(stderr, "test_gru_int8_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +static int test_gru_4() +{ + return 0 + || test_gru_int8(RandomMat(4, 1), 2, 2) + || test_gru_int8(RandomMat(8, 2), 2, 2) + || test_gru_int8(RandomMat(16, 8), 7, 2) + || test_gru_int8(RandomMat(17, 8), 8, 2) + || test_gru_int8(RandomMat(19, 15), 8, 2) + || test_gru_int8(RandomMat(5, 16), 16, 2) + || test_gru_int8(RandomMat(3, 16), 8, 2) + || test_gru_int8(RandomMat(8, 16), 16, 2) + || test_gru_int8(RandomMat(31, 3), 31, 2) + || test_gru_int8(RandomMat(2, 5), 17, 2); +} + +static int test_gru_5() +{ + return 0 + || test_gru_int8_with_hidden(RandomMat(4, 4), 1, 2) + || test_gru_int8_with_hidden(RandomMat(8, 2), 2, 2) + || test_gru_int8_with_hidden(RandomMat(16, 8), 7, 2) + || test_gru_int8_with_hidden(RandomMat(17, 8), 8, 2) + || test_gru_int8_with_hidden(RandomMat(19, 15), 8, 2) + || test_gru_int8_with_hidden(RandomMat(5, 16), 16, 2) + || test_gru_int8_with_hidden(RandomMat(3, 16), 8, 2) + || test_gru_int8_with_hidden(RandomMat(2, 5), 79, 2) + || test_gru_int8_with_hidden(RandomMat(4, 4), 1, 1) + || test_gru_int8_with_hidden(RandomMat(8, 2), 2, 1) + || test_gru_int8_with_hidden(RandomMat(16, 8), 7, 1) + || test_gru_int8_with_hidden(RandomMat(17, 8), 8, 1) + || test_gru_int8_with_hidden(RandomMat(19, 15), 8, 1) + || test_gru_int8_with_hidden(RandomMat(5, 16), 16, 1) + || test_gru_int8_with_hidden(RandomMat(3, 16), 8, 1) + || test_gru_int8_with_hidden(RandomMat(2, 5), 79, 1) + || test_gru_int8_with_hidden(RandomMat(4, 2), 1, 0) + || test_gru_int8_with_hidden(RandomMat(8, 2), 2, 0) + || test_gru_int8_with_hidden(RandomMat(16, 8), 7, 0) + || test_gru_int8_with_hidden(RandomMat(17, 8), 8, 0) + || test_gru_int8_with_hidden(RandomMat(19, 15), 8, 0) + || test_gru_int8_with_hidden(RandomMat(5, 16), 16, 0) + || test_gru_int8_with_hidden(RandomMat(3, 16), 8, 0) + || test_gru_int8_with_hidden(RandomMat(2, 5), 17, 0) + + || test_gru_int8_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_gru_int8_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_gru_int8_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_gru_int8_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_gru_int8_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_gru_int8_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_gru_int8_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_gru_int8_with_hidden_input(RandomMat(2, 5), 79, 2) + || test_gru_int8_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_gru_int8_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_gru_int8_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_gru_int8_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_gru_int8_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_gru_int8_with_hidden_input(RandomMat(5, 16), 16, 1) + || test_gru_int8_with_hidden_input(RandomMat(3, 16), 8, 1) + 
|| test_gru_int8_with_hidden_input(RandomMat(2, 5), 79, 1) + || test_gru_int8_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_gru_int8_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_gru_int8_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_gru_int8_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_gru_int8_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_gru_int8_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_gru_int8_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_gru_int8_with_hidden_input(RandomMat(2, 5), 17, 0) + + || test_gru_int8_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_gru_int8_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_gru_int8_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_gru_int8_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_gru_int8_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_gru_int8_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_gru_int8_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_gru_int8_with_hidden_output(RandomMat(2, 5), 79, 2) + || test_gru_int8_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_gru_int8_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_gru_int8_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_gru_int8_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_gru_int8_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_gru_int8_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_gru_int8_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_gru_int8_with_hidden_output(RandomMat(2, 5), 79, 1) + || test_gru_int8_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_gru_int8_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_gru_int8_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_gru_int8_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_gru_int8_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_gru_int8_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_gru_int8_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_gru_int8_with_hidden_output(RandomMat(2, 5), 17, 0); +} + +static int test_gru_6() +{ + return 0 + || test_gru_int8(RandomMat(4, 1), 1, 0) + || test_gru_int8(RandomMat(8, 2), 2, 0) + || test_gru_int8(RandomMat(16, 8), 7, 0) + || test_gru_int8(RandomMat(17, 8), 8, 0) + || test_gru_int8(RandomMat(19, 15), 8, 0) + || test_gru_int8(RandomMat(5, 16), 16, 0) + || test_gru_int8(RandomMat(3, 16), 8, 0) + || test_gru_int8(RandomMat(8, 16), 16, 0) + || test_gru_int8(RandomMat(2, 5), 17, 0); +} + +static int test_gru_7() +{ + return 0 + || test_gru_int8(RandomMat(4, 1), 1, 1) + || test_gru_int8(RandomMat(8, 2), 2, 1) + || test_gru_int8(RandomMat(16, 8), 7, 1) + || test_gru_int8(RandomMat(17, 8), 8, 1) + || test_gru_int8(RandomMat(19, 15), 8, 1) + || test_gru_int8(RandomMat(5, 16), 16, 1) + || test_gru_int8(RandomMat(3, 16), 8, 1) + || test_gru_int8(RandomMat(8, 16), 16, 1) + || test_gru_int8(RandomMat(2, 5), 17, 1); +} +#endif + int main() { SRAND(7767517); - return test_gru_0() || test_gru_1() || test_gru_2() || test_gru_3(); + +#if NCNN_INT8 + return 0 + || test_gru_0() + || test_gru_1() + || test_gru_2() + || test_gru_3() + || test_gru_4() + || test_gru_5() + || test_gru_6() + || test_gru_7(); +#else + return 0 + || test_gru_0() + || test_gru_1() + || test_gru_2() + || test_gru_3(); +#endif } diff --git a/tests/test_lstm.cpp b/tests/test_lstm.cpp index 8b5788a86dc..21a8a4bcfc8 100644 --- a/tests/test_lstm.cpp +++ b/tests/test_lstm.cpp @@ -27,11 +27,11 @@ static int test_lstm(const ncnn::Mat& a, int outch, int direction, int hidden_si pd.set(2, 
direction); pd.set(3, hidden_size); - std::vector weights(hidden_size == 0 ? 3 : 4); + std::vector weights(hidden_size == outch ? 3 : 4); weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); weights[1] = RandomMat(hidden_size * 4 * num_directions); weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); - if (hidden_size) + if (hidden_size != outch) { weights[3] = RandomMat(hidden_size * outch * num_directions); } @@ -45,7 +45,7 @@ static int test_lstm(const ncnn::Mat& a, int outch, int direction, int hidden_si return ret; } -int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +static int test_lstm_with_hidden(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -58,11 +58,11 @@ int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction, in pd.set(2, direction); pd.set(3, hidden_size); - std::vector weights(hidden_size == 0 ? 3 : 4); + std::vector weights(hidden_size == outch ? 3 : 4); weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); weights[1] = RandomMat(hidden_size * 4 * num_directions); weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); - if (hidden_size) + if (hidden_size != outch) { weights[3] = RandomMat(hidden_size * outch * num_directions); } @@ -81,13 +81,13 @@ int test_lstm_layer_with_hidden(const ncnn::Mat& a, int outch, int direction, in int ret = test_layer("LSTM", pd, weights, as, 3); if (ret != 0) { - fprintf(stderr, "test_lstm_layer_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + fprintf(stderr, "test_lstm_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); } return ret; } -int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +static int test_lstm_with_hidden_input(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -100,11 +100,11 @@ int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int directi pd.set(2, direction); pd.set(3, hidden_size); - std::vector weights(hidden_size == 0 ? 3 : 4); + std::vector weights(hidden_size == outch ? 
3 : 4); weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); weights[1] = RandomMat(hidden_size * 4 * num_directions); weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); - if (hidden_size) + if (hidden_size != outch) { weights[3] = RandomMat(hidden_size * outch * num_directions); } @@ -123,13 +123,13 @@ int test_lstm_layer_with_hidden_input(const ncnn::Mat& a, int outch, int directi int ret = test_layer("LSTM", pd, weights, as, 1); if (ret != 0) { - fprintf(stderr, "test_lstm_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + fprintf(stderr, "test_lstm_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); } return ret; } -int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +static int test_lstm_with_hidden_output(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -142,11 +142,11 @@ int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direct pd.set(2, direction); pd.set(3, hidden_size); - std::vector weights(hidden_size == 0 ? 3 : 4); + std::vector weights(hidden_size == outch ? 3 : 4); weights[0] = RandomMat(hidden_size * input_size * 4 * num_directions); weights[1] = RandomMat(hidden_size * 4 * num_directions); weights[2] = RandomMat(outch * hidden_size * 4 * num_directions); - if (hidden_size) + if (hidden_size != outch) { weights[3] = RandomMat(hidden_size * outch * num_directions); } @@ -157,7 +157,7 @@ int test_lstm_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direct int ret = test_layer("LSTM", pd, weights, as, 3); if (ret != 0) { - fprintf(stderr, "test_lstm_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + fprintf(stderr, "test_lstm_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); } return ret; @@ -174,86 +174,87 @@ static int test_lstm_0() || test_lstm(RandomMat(5, 16), 16, 2) || test_lstm(RandomMat(3, 16), 8, 2) || test_lstm(RandomMat(8, 16), 16, 2) + || test_lstm(RandomMat(31, 3), 31, 2) || test_lstm(RandomMat(2, 5), 17, 2, 15); } static int test_lstm_1() { return 0 - || test_lstm_layer_with_hidden(RandomMat(4, 4), 1, 2) - || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 2) - || test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 2) - || test_lstm_layer_with_hidden(RandomMat(17, 8), 8, 2) - || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 2) - || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 2) - || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 2) - || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 2, 33) - || test_lstm_layer_with_hidden(RandomMat(4, 4), 1, 1) - || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 1) - || test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 1) - || test_lstm_layer_with_hidden(RandomMat(17, 8), 8, 1) - || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 1) - || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 1) - || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 1) - || test_lstm_layer_with_hidden(RandomMat(2, 5), 99, 1, 33) - || test_lstm_layer_with_hidden(RandomMat(4, 2), 1, 0) - || test_lstm_layer_with_hidden(RandomMat(8, 2), 2, 0) - || 
test_lstm_layer_with_hidden(RandomMat(16, 8), 7, 0) - || test_lstm_layer_with_hidden(RandomMat(17, 8), 8, 0) - || test_lstm_layer_with_hidden(RandomMat(19, 15), 8, 0) - || test_lstm_layer_with_hidden(RandomMat(5, 16), 16, 0) - || test_lstm_layer_with_hidden(RandomMat(3, 16), 8, 0) - || test_lstm_layer_with_hidden(RandomMat(2, 5), 17, 0, 15) - - || test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 2) - || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 2) - || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 2) - || test_lstm_layer_with_hidden_input(RandomMat(17, 8), 8, 2) - || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 2) - || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 2) - || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 2) - || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 2, 33) - || test_lstm_layer_with_hidden_input(RandomMat(4, 4), 1, 1) - || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 1) - || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 1) - || test_lstm_layer_with_hidden_input(RandomMat(17, 8), 8, 1) - || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 1) - || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 1) - || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 1) - || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 99, 1, 33) - || test_lstm_layer_with_hidden_input(RandomMat(4, 2), 1, 0) - || test_lstm_layer_with_hidden_input(RandomMat(8, 2), 2, 0) - || test_lstm_layer_with_hidden_input(RandomMat(16, 8), 7, 0) - || test_lstm_layer_with_hidden_input(RandomMat(17, 8), 8, 0) - || test_lstm_layer_with_hidden_input(RandomMat(19, 15), 8, 0) - || test_lstm_layer_with_hidden_input(RandomMat(5, 16), 16, 0) - || test_lstm_layer_with_hidden_input(RandomMat(3, 16), 8, 0) - || test_lstm_layer_with_hidden_input(RandomMat(2, 5), 17, 0, 15) - - || test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 2) - || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 2) - || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 2) - || test_lstm_layer_with_hidden_output(RandomMat(17, 8), 8, 2) - || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 2) - || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 2) - || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 2) - || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 2, 33) - || test_lstm_layer_with_hidden_output(RandomMat(4, 4), 1, 1) - || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 1) - || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 1) - || test_lstm_layer_with_hidden_output(RandomMat(17, 8), 8, 1) - || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 1) - || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 1) - || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 1) - || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 99, 1, 33) - || test_lstm_layer_with_hidden_output(RandomMat(4, 2), 1, 0) - || test_lstm_layer_with_hidden_output(RandomMat(8, 2), 2, 0) - || test_lstm_layer_with_hidden_output(RandomMat(16, 8), 7, 0) - || test_lstm_layer_with_hidden_output(RandomMat(17, 8), 8, 0) - || test_lstm_layer_with_hidden_output(RandomMat(19, 15), 8, 0) - || test_lstm_layer_with_hidden_output(RandomMat(5, 16), 16, 0) - || test_lstm_layer_with_hidden_output(RandomMat(3, 16), 8, 0) - || test_lstm_layer_with_hidden_output(RandomMat(2, 5), 17, 0, 15); + || test_lstm_with_hidden(RandomMat(4, 4), 1, 2) + || test_lstm_with_hidden(RandomMat(8, 2), 2, 2) + || 
test_lstm_with_hidden(RandomMat(16, 8), 7, 2) + || test_lstm_with_hidden(RandomMat(17, 8), 8, 2) + || test_lstm_with_hidden(RandomMat(19, 15), 8, 2) + || test_lstm_with_hidden(RandomMat(5, 16), 16, 2) + || test_lstm_with_hidden(RandomMat(3, 16), 8, 2) + || test_lstm_with_hidden(RandomMat(2, 5), 79, 2, 33) + || test_lstm_with_hidden(RandomMat(4, 4), 1, 1) + || test_lstm_with_hidden(RandomMat(8, 2), 2, 1) + || test_lstm_with_hidden(RandomMat(16, 8), 7, 1) + || test_lstm_with_hidden(RandomMat(17, 8), 8, 1) + || test_lstm_with_hidden(RandomMat(19, 15), 8, 1) + || test_lstm_with_hidden(RandomMat(5, 16), 16, 1) + || test_lstm_with_hidden(RandomMat(3, 16), 8, 1) + || test_lstm_with_hidden(RandomMat(2, 5), 79, 1, 33) + || test_lstm_with_hidden(RandomMat(4, 2), 1, 0) + || test_lstm_with_hidden(RandomMat(8, 2), 2, 0) + || test_lstm_with_hidden(RandomMat(16, 8), 7, 0) + || test_lstm_with_hidden(RandomMat(17, 8), 8, 0) + || test_lstm_with_hidden(RandomMat(19, 15), 8, 0) + || test_lstm_with_hidden(RandomMat(5, 16), 16, 0) + || test_lstm_with_hidden(RandomMat(3, 16), 8, 0) + || test_lstm_with_hidden(RandomMat(2, 5), 17, 0, 15) + + || test_lstm_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_lstm_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_lstm_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_lstm_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_lstm_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_lstm_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_lstm_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_lstm_with_hidden_input(RandomMat(2, 5), 79, 2, 33) + || test_lstm_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_lstm_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_lstm_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_lstm_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_lstm_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_lstm_with_hidden_input(RandomMat(5, 16), 16, 1) + || test_lstm_with_hidden_input(RandomMat(3, 16), 8, 1) + || test_lstm_with_hidden_input(RandomMat(2, 5), 79, 1, 33) + || test_lstm_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_lstm_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_lstm_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_lstm_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_lstm_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_lstm_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_lstm_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_lstm_with_hidden_input(RandomMat(2, 5), 17, 0, 15) + + || test_lstm_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_lstm_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_lstm_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_lstm_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_lstm_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_lstm_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_lstm_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_lstm_with_hidden_output(RandomMat(2, 5), 79, 2, 33) + || test_lstm_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_lstm_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_lstm_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_lstm_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_lstm_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_lstm_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_lstm_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_lstm_with_hidden_output(RandomMat(2, 5), 79, 1, 33) + || test_lstm_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_lstm_with_hidden_output(RandomMat(8, 
2), 2, 0) + || test_lstm_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_lstm_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_lstm_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_lstm_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_lstm_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_lstm_with_hidden_output(RandomMat(2, 5), 17, 0, 15); } static int test_lstm_2() @@ -269,6 +270,7 @@ static int test_lstm_2() || test_lstm(RandomMat(8, 16), 16, 0) || test_lstm(RandomMat(2, 5), 17, 0, 15); } + static int test_lstm_3() { return 0 @@ -283,8 +285,330 @@ static int test_lstm_3() || test_lstm(RandomMat(2, 5), 17, 1, 15); } +#if NCNN_INT8 +static int test_lstm_int8(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, hidden_size * input_size * 4 * num_directions); + pd.set(2, direction); + pd.set(3, hidden_size); + pd.set(8, 2); // int8_scale_term + + std::vector weights(hidden_size == outch ? 5 : 6); + weights[0] = RandomS8Mat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomS8Mat(outch * hidden_size * 4 * num_directions); + if (hidden_size != outch) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[5] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + else + { + weights[3] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + + int ret = test_layer("LSTM", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_lstm_int8 failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + } + + return ret; +} + +static int test_lstm_int8_with_hidden(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, hidden_size * input_size * 4 * num_directions); + pd.set(2, direction); + pd.set(3, hidden_size); + pd.set(8, 2); // int8_scale_term + + std::vector weights(hidden_size == outch ? 
5 : 6); + weights[0] = RandomS8Mat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomS8Mat(outch * hidden_size * 4 * num_directions); + if (hidden_size != outch) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[5] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + else + { + weights[3] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + + // initial hidden state + ncnn::Mat hidden = RandomMat(outch, num_directions); + + // initial cell state + ncnn::Mat cell = RandomMat(hidden_size, num_directions); + + std::vector as(3); + as[0] = a; + as[1] = hidden; + as[2] = cell; + + int ret = test_layer("LSTM", pd, weights, as, 3); + if (ret != 0) + { + fprintf(stderr, "test_lstm_int8_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + } + + return ret; +} + +static int test_lstm_int8_with_hidden_input(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, hidden_size * input_size * 4 * num_directions); + pd.set(2, direction); + pd.set(3, hidden_size); + pd.set(8, 2); // int8_scale_term + + std::vector weights(hidden_size == outch ? 5 : 6); + weights[0] = RandomS8Mat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomS8Mat(outch * hidden_size * 4 * num_directions); + if (hidden_size != outch) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[5] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + else + { + weights[3] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + + // initial hidden state + ncnn::Mat hidden = RandomMat(outch, num_directions); + + // initial cell state + ncnn::Mat cell = RandomMat(hidden_size, num_directions); + + std::vector as(3); + as[0] = a; + as[1] = hidden; + as[2] = cell; + + int ret = test_layer("LSTM", pd, weights, as, 1); + if (ret != 0) + { + fprintf(stderr, "test_lstm_int8_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + } + + return ret; +} + +static int test_lstm_int8_with_hidden_output(const ncnn::Mat& a, int outch, int direction, int hidden_size = 0) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + if (hidden_size == 0) + hidden_size = outch; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, hidden_size * input_size * 4 * num_directions); + pd.set(2, direction); + pd.set(3, hidden_size); + pd.set(8, 2); // int8_scale_term + + std::vector weights(hidden_size == outch ? 
5 : 6); + weights[0] = RandomS8Mat(hidden_size * input_size * 4 * num_directions); + weights[1] = RandomMat(hidden_size * 4 * num_directions); + weights[2] = RandomS8Mat(outch * hidden_size * 4 * num_directions); + if (hidden_size != outch) + { + weights[3] = RandomMat(hidden_size * outch * num_directions); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[5] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + else + { + weights[3] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + weights[4] = RandomMat(hidden_size * 4 * num_directions, 100.f, 200.f); + } + + std::vector as(1); + as[0] = a; + + int ret = test_layer("LSTM", pd, weights, as, 3); + if (ret != 0) + { + fprintf(stderr, "test_lstm_int8_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d hidden_size=%d\n", a.dims, a.w, a.h, a.c, outch, direction, hidden_size); + } + + return ret; +} + +static int test_lstm_4() +{ + return 0 + || test_lstm_int8(RandomMat(4, 1), 2, 2) + || test_lstm_int8(RandomMat(8, 2), 2, 2) + || test_lstm_int8(RandomMat(16, 8), 7, 2) + || test_lstm_int8(RandomMat(17, 8), 8, 2) + || test_lstm_int8(RandomMat(19, 15), 8, 2) + || test_lstm_int8(RandomMat(5, 16), 16, 2) + || test_lstm_int8(RandomMat(3, 16), 8, 2) + || test_lstm_int8(RandomMat(8, 16), 16, 2) + || test_lstm_int8(RandomMat(31, 3), 31, 2) + || test_lstm_int8(RandomMat(2, 5), 17, 2, 15); +} + +static int test_lstm_5() +{ + return 0 + || test_lstm_int8_with_hidden(RandomMat(4, 4), 1, 2) + || test_lstm_int8_with_hidden(RandomMat(8, 2), 2, 2) + || test_lstm_int8_with_hidden(RandomMat(16, 8), 7, 2) + || test_lstm_int8_with_hidden(RandomMat(17, 8), 8, 2) + || test_lstm_int8_with_hidden(RandomMat(19, 15), 8, 2) + || test_lstm_int8_with_hidden(RandomMat(5, 16), 16, 2) + || test_lstm_int8_with_hidden(RandomMat(3, 16), 8, 2) + || test_lstm_int8_with_hidden(RandomMat(2, 5), 79, 2, 33) + || test_lstm_int8_with_hidden(RandomMat(4, 4), 1, 1) + || test_lstm_int8_with_hidden(RandomMat(8, 2), 2, 1) + || test_lstm_int8_with_hidden(RandomMat(16, 8), 7, 1) + || test_lstm_int8_with_hidden(RandomMat(17, 8), 8, 1) + || test_lstm_int8_with_hidden(RandomMat(19, 15), 8, 1) + || test_lstm_int8_with_hidden(RandomMat(5, 16), 16, 1) + || test_lstm_int8_with_hidden(RandomMat(3, 16), 8, 1) + || test_lstm_int8_with_hidden(RandomMat(2, 5), 79, 1, 33) + || test_lstm_int8_with_hidden(RandomMat(4, 2), 1, 0) + || test_lstm_int8_with_hidden(RandomMat(8, 2), 2, 0) + || test_lstm_int8_with_hidden(RandomMat(16, 8), 7, 0) + || test_lstm_int8_with_hidden(RandomMat(17, 8), 8, 0) + || test_lstm_int8_with_hidden(RandomMat(19, 15), 8, 0) + || test_lstm_int8_with_hidden(RandomMat(5, 16), 16, 0) + || test_lstm_int8_with_hidden(RandomMat(3, 16), 8, 0) + || test_lstm_int8_with_hidden(RandomMat(2, 5), 17, 0, 15) + + || test_lstm_int8_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_lstm_int8_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_lstm_int8_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_lstm_int8_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_lstm_int8_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_lstm_int8_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_lstm_int8_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_lstm_int8_with_hidden_input(RandomMat(2, 5), 79, 2, 33) + || test_lstm_int8_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_lstm_int8_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_lstm_int8_with_hidden_input(RandomMat(16, 8), 7, 1) + || 
test_lstm_int8_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_lstm_int8_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_lstm_int8_with_hidden_input(RandomMat(5, 16), 16, 1) + || test_lstm_int8_with_hidden_input(RandomMat(3, 16), 8, 1) + || test_lstm_int8_with_hidden_input(RandomMat(2, 5), 79, 1, 33) + || test_lstm_int8_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_lstm_int8_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_lstm_int8_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_lstm_int8_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_lstm_int8_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_lstm_int8_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_lstm_int8_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_lstm_int8_with_hidden_input(RandomMat(2, 5), 17, 0, 15) + + || test_lstm_int8_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_lstm_int8_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_lstm_int8_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_lstm_int8_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_lstm_int8_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_lstm_int8_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_lstm_int8_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_lstm_int8_with_hidden_output(RandomMat(2, 5), 79, 2, 33) + || test_lstm_int8_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_lstm_int8_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_lstm_int8_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_lstm_int8_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_lstm_int8_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_lstm_int8_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_lstm_int8_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_lstm_int8_with_hidden_output(RandomMat(2, 5), 79, 1, 33) + || test_lstm_int8_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_lstm_int8_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_lstm_int8_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_lstm_int8_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_lstm_int8_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_lstm_int8_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_lstm_int8_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_lstm_int8_with_hidden_output(RandomMat(2, 5), 17, 0, 15); +} + +static int test_lstm_6() +{ + return 0 + || test_lstm_int8(RandomMat(4, 1), 1, 0) + || test_lstm_int8(RandomMat(8, 2), 2, 0) + || test_lstm_int8(RandomMat(16, 8), 7, 0) + || test_lstm_int8(RandomMat(17, 8), 8, 0) + || test_lstm_int8(RandomMat(19, 15), 8, 0) + || test_lstm_int8(RandomMat(5, 16), 16, 0) + || test_lstm_int8(RandomMat(3, 16), 8, 0) + || test_lstm_int8(RandomMat(8, 16), 16, 0) + || test_lstm_int8(RandomMat(2, 5), 17, 0, 15); +} + +static int test_lstm_7() +{ + return 0 + || test_lstm_int8(RandomMat(4, 1), 1, 1) + || test_lstm_int8(RandomMat(8, 2), 2, 1) + || test_lstm_int8(RandomMat(16, 8), 7, 1) + || test_lstm_int8(RandomMat(17, 8), 8, 1) + || test_lstm_int8(RandomMat(19, 15), 8, 1) + || test_lstm_int8(RandomMat(5, 16), 16, 1) + || test_lstm_int8(RandomMat(3, 16), 8, 1) + || test_lstm_int8(RandomMat(8, 16), 16, 1) + || test_lstm_int8(RandomMat(2, 5), 17, 1, 15); +} +#endif + int main() { SRAND(7767517); - return 0 || test_lstm_0() || test_lstm_1() || test_lstm_2() || test_lstm_3(); + +#if NCNN_INT8 + return 0 + || test_lstm_0() + || test_lstm_1() + || test_lstm_2() + || test_lstm_3() + || test_lstm_4() + || test_lstm_5() + || test_lstm_6() + || test_lstm_7(); +#else + return 0 + 
|| test_lstm_0() + || test_lstm_1() + || test_lstm_2() + || test_lstm_3(); +#endif } diff --git a/tests/test_rnn.cpp b/tests/test_rnn.cpp index f9cb9a5d752..c17b23b5c5b 100644 --- a/tests/test_rnn.cpp +++ b/tests/test_rnn.cpp @@ -32,13 +32,13 @@ static int test_rnn(const ncnn::Mat& a, int outch, int direction) int ret = test_layer("RNN", pd, weights, a); if (ret != 0) { - fprintf(stderr, "test_rnn failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_rnn failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; } -int test_rnn_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) +int test_rnn_with_hidden(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -54,7 +54,7 @@ int test_rnn_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) weights[2] = RandomMat(outch * outch * num_directions); // initial hidden state - ncnn::Mat hidden = RandomMat(outch, num_directions); + ncnn::Mat hidden = RandomMat(outch, num_directions, -1.f, 1.f); std::vector as(2); as[0] = a; @@ -63,13 +63,13 @@ int test_rnn_layer_with_hidden(const ncnn::Mat& a, int outch, int direction) int ret = test_layer("RNN", pd, weights, as, 2); if (ret != 0) { - fprintf(stderr, "test_rnn_layer_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_rnn_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; } -int test_rnn_layer_with_hidden_input(const ncnn::Mat& a, int outch, int direction) +int test_rnn_with_hidden_input(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; int num_directions = direction == 2 ? 2 : 1; @@ -85,7 +85,7 @@ int test_rnn_layer_with_hidden_input(const ncnn::Mat& a, int outch, int directio weights[2] = RandomMat(outch * outch * num_directions); // initial hidden state - ncnn::Mat hidden = RandomMat(outch, num_directions); + ncnn::Mat hidden = RandomMat(outch, num_directions, -1.f, 1.f); std::vector as(2); as[0] = a; @@ -94,13 +94,13 @@ int test_rnn_layer_with_hidden_input(const ncnn::Mat& a, int outch, int directio int ret = test_layer("RNN", pd, weights, as, 1); if (ret != 0) { - fprintf(stderr, "test_rnn_layer_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_rnn_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; } -int test_rnn_layer_with_hidden_output(const ncnn::Mat& a, int outch, int direction) +int test_rnn_with_hidden_output(const ncnn::Mat& a, int outch, int direction) { int input_size = a.w; int num_directions = direction == 2 ? 
2 : 1; @@ -121,7 +121,7 @@ int test_rnn_layer_with_hidden_output(const ncnn::Mat& a, int outch, int directi int ret = test_layer("RNN", pd, weights, as, 2); if (ret != 0) { - fprintf(stderr, "test_rnn_layer_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d, direction = %d \n", a.dims, a.w, a.h, a.c, outch, direction); + fprintf(stderr, "test_rnn_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); } return ret; @@ -138,86 +138,87 @@ static int test_rnn_0() || test_rnn(RandomMat(5, 16), 16, 2) || test_rnn(RandomMat(3, 16), 8, 2) || test_rnn(RandomMat(8, 16), 16, 2) + || test_rnn(RandomMat(31, 3), 31, 2) || test_rnn(RandomMat(2, 5), 17, 2); } static int test_rnn_1() { return 0 - || test_rnn_layer_with_hidden(RandomMat(4, 4), 1, 2) - || test_rnn_layer_with_hidden(RandomMat(8, 2), 2, 2) - || test_rnn_layer_with_hidden(RandomMat(16, 8), 7, 2) - || test_rnn_layer_with_hidden(RandomMat(17, 8), 8, 2) - || test_rnn_layer_with_hidden(RandomMat(19, 15), 8, 2) - || test_rnn_layer_with_hidden(RandomMat(5, 16), 16, 2) - || test_rnn_layer_with_hidden(RandomMat(3, 16), 8, 2) - || test_rnn_layer_with_hidden(RandomMat(2, 5), 99, 2) - || test_rnn_layer_with_hidden(RandomMat(4, 4), 1, 1) - || test_rnn_layer_with_hidden(RandomMat(8, 2), 2, 1) - || test_rnn_layer_with_hidden(RandomMat(16, 8), 7, 1) - || test_rnn_layer_with_hidden(RandomMat(17, 8), 8, 1) - || test_rnn_layer_with_hidden(RandomMat(19, 15), 8, 1) - || test_rnn_layer_with_hidden(RandomMat(5, 16), 16, 1) - || test_rnn_layer_with_hidden(RandomMat(3, 16), 8, 1) - || test_rnn_layer_with_hidden(RandomMat(2, 5), 99, 1) - || test_rnn_layer_with_hidden(RandomMat(4, 2), 1, 0) - || test_rnn_layer_with_hidden(RandomMat(8, 2), 2, 0) - || test_rnn_layer_with_hidden(RandomMat(16, 8), 7, 0) - || test_rnn_layer_with_hidden(RandomMat(17, 8), 8, 0) - || test_rnn_layer_with_hidden(RandomMat(19, 15), 8, 0) - || test_rnn_layer_with_hidden(RandomMat(5, 16), 16, 0) - || test_rnn_layer_with_hidden(RandomMat(3, 16), 8, 0) - || test_rnn_layer_with_hidden(RandomMat(2, 5), 17, 0) - - || test_rnn_layer_with_hidden_input(RandomMat(4, 4), 1, 2) - || test_rnn_layer_with_hidden_input(RandomMat(8, 2), 2, 2) - || test_rnn_layer_with_hidden_input(RandomMat(16, 8), 7, 2) - || test_rnn_layer_with_hidden_input(RandomMat(17, 8), 8, 2) - || test_rnn_layer_with_hidden_input(RandomMat(19, 15), 8, 2) - || test_rnn_layer_with_hidden_input(RandomMat(5, 16), 16, 2) - || test_rnn_layer_with_hidden_input(RandomMat(3, 16), 8, 2) - || test_rnn_layer_with_hidden_input(RandomMat(2, 5), 99, 2) - || test_rnn_layer_with_hidden_input(RandomMat(4, 4), 1, 1) - || test_rnn_layer_with_hidden_input(RandomMat(8, 2), 2, 1) - || test_rnn_layer_with_hidden_input(RandomMat(16, 8), 7, 1) - || test_rnn_layer_with_hidden_input(RandomMat(17, 8), 8, 1) - || test_rnn_layer_with_hidden_input(RandomMat(19, 15), 8, 1) - || test_rnn_layer_with_hidden_input(RandomMat(5, 16), 16, 1) - || test_rnn_layer_with_hidden_input(RandomMat(3, 16), 8, 1) - || test_rnn_layer_with_hidden_input(RandomMat(2, 5), 99, 1) - || test_rnn_layer_with_hidden_input(RandomMat(4, 2), 1, 0) - || test_rnn_layer_with_hidden_input(RandomMat(8, 2), 2, 0) - || test_rnn_layer_with_hidden_input(RandomMat(16, 8), 7, 0) - || test_rnn_layer_with_hidden_input(RandomMat(17, 8), 8, 0) - || test_rnn_layer_with_hidden_input(RandomMat(19, 15), 8, 0) - || test_rnn_layer_with_hidden_input(RandomMat(5, 16), 16, 0) - || test_rnn_layer_with_hidden_input(RandomMat(3, 16), 8, 0) - || 
test_rnn_layer_with_hidden_input(RandomMat(2, 5), 17, 0) - - || test_rnn_layer_with_hidden_output(RandomMat(4, 4), 1, 2) - || test_rnn_layer_with_hidden_output(RandomMat(8, 2), 2, 2) - || test_rnn_layer_with_hidden_output(RandomMat(16, 8), 7, 2) - || test_rnn_layer_with_hidden_output(RandomMat(17, 8), 8, 2) - || test_rnn_layer_with_hidden_output(RandomMat(19, 15), 8, 2) - || test_rnn_layer_with_hidden_output(RandomMat(5, 16), 16, 2) - || test_rnn_layer_with_hidden_output(RandomMat(3, 16), 8, 2) - || test_rnn_layer_with_hidden_output(RandomMat(2, 5), 99, 2) - || test_rnn_layer_with_hidden_output(RandomMat(4, 4), 1, 1) - || test_rnn_layer_with_hidden_output(RandomMat(8, 2), 2, 1) - || test_rnn_layer_with_hidden_output(RandomMat(16, 8), 7, 1) - || test_rnn_layer_with_hidden_output(RandomMat(17, 8), 8, 1) - || test_rnn_layer_with_hidden_output(RandomMat(19, 15), 8, 1) - || test_rnn_layer_with_hidden_output(RandomMat(5, 16), 16, 1) - || test_rnn_layer_with_hidden_output(RandomMat(3, 16), 8, 1) - || test_rnn_layer_with_hidden_output(RandomMat(2, 5), 99, 1) - || test_rnn_layer_with_hidden_output(RandomMat(4, 2), 1, 0) - || test_rnn_layer_with_hidden_output(RandomMat(8, 2), 2, 0) - || test_rnn_layer_with_hidden_output(RandomMat(16, 8), 7, 0) - || test_rnn_layer_with_hidden_output(RandomMat(17, 8), 8, 0) - || test_rnn_layer_with_hidden_output(RandomMat(19, 15), 8, 0) - || test_rnn_layer_with_hidden_output(RandomMat(5, 16), 16, 0) - || test_rnn_layer_with_hidden_output(RandomMat(3, 16), 8, 0) - || test_rnn_layer_with_hidden_output(RandomMat(2, 5), 17, 0); + || test_rnn_with_hidden(RandomMat(4, 4), 1, 2) + || test_rnn_with_hidden(RandomMat(8, 2), 2, 2) + || test_rnn_with_hidden(RandomMat(16, 8), 7, 2) + || test_rnn_with_hidden(RandomMat(17, 8), 8, 2) + || test_rnn_with_hidden(RandomMat(19, 15), 8, 2) + || test_rnn_with_hidden(RandomMat(5, 16), 16, 2) + || test_rnn_with_hidden(RandomMat(3, 16), 8, 2) + || test_rnn_with_hidden(RandomMat(2, 5), 79, 2) + || test_rnn_with_hidden(RandomMat(4, 4), 1, 1) + || test_rnn_with_hidden(RandomMat(8, 2), 2, 1) + || test_rnn_with_hidden(RandomMat(16, 8), 7, 1) + || test_rnn_with_hidden(RandomMat(17, 8), 8, 1) + || test_rnn_with_hidden(RandomMat(19, 15), 8, 1) + || test_rnn_with_hidden(RandomMat(5, 16), 16, 1) + || test_rnn_with_hidden(RandomMat(3, 16), 8, 1) + || test_rnn_with_hidden(RandomMat(2, 5), 79, 1) + || test_rnn_with_hidden(RandomMat(4, 2), 1, 0) + || test_rnn_with_hidden(RandomMat(8, 2), 2, 0) + || test_rnn_with_hidden(RandomMat(16, 8), 7, 0) + || test_rnn_with_hidden(RandomMat(17, 8), 8, 0) + || test_rnn_with_hidden(RandomMat(19, 15), 8, 0) + || test_rnn_with_hidden(RandomMat(5, 16), 16, 0) + || test_rnn_with_hidden(RandomMat(3, 16), 8, 0) + || test_rnn_with_hidden(RandomMat(2, 5), 17, 0) + + || test_rnn_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_rnn_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_rnn_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_rnn_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_rnn_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_rnn_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_rnn_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_rnn_with_hidden_input(RandomMat(2, 5), 79, 2) + || test_rnn_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_rnn_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_rnn_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_rnn_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_rnn_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_rnn_with_hidden_input(RandomMat(5, 16), 16, 
1) + || test_rnn_with_hidden_input(RandomMat(3, 16), 8, 1) + || test_rnn_with_hidden_input(RandomMat(2, 5), 79, 1) + || test_rnn_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_rnn_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_rnn_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_rnn_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_rnn_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_rnn_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_rnn_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_rnn_with_hidden_input(RandomMat(2, 5), 17, 0) + + || test_rnn_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_rnn_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_rnn_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_rnn_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_rnn_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_rnn_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_rnn_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_rnn_with_hidden_output(RandomMat(2, 5), 79, 2) + || test_rnn_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_rnn_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_rnn_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_rnn_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_rnn_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_rnn_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_rnn_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_rnn_with_hidden_output(RandomMat(2, 5), 79, 1) + || test_rnn_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_rnn_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_rnn_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_rnn_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_rnn_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_rnn_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_rnn_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_rnn_with_hidden_output(RandomMat(2, 5), 17, 0); } static int test_rnn_2() @@ -248,8 +249,274 @@ static int test_rnn_3() || test_rnn(RandomMat(2, 5), 17, 1); } +#if NCNN_INT8 +static int test_rnn_int8(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * num_directions); + pd.set(2, direction); + pd.set(8, 2); // int8_scale_term + + std::vector weights(5); + weights[0] = RandomS8Mat(outch * input_size * num_directions); + weights[1] = RandomMat(outch * num_directions); + weights[2] = RandomS8Mat(outch * outch * num_directions); + weights[3] = RandomMat(outch * num_directions, 100.f, 200.f); + weights[4] = RandomMat(outch * num_directions, 100.f, 200.f); + + int ret = test_layer("RNN", pd, weights, a); + if (ret != 0) + { + fprintf(stderr, "test_rnn_int8 failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +int test_rnn_int8_with_hidden(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 
2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * num_directions); + pd.set(2, direction); + pd.set(8, 2); // int8_scale_term + + std::vector weights(5); + weights[0] = RandomS8Mat(outch * input_size * num_directions); + weights[1] = RandomMat(outch * num_directions); + weights[2] = RandomS8Mat(outch * outch * num_directions); + weights[3] = RandomMat(outch * num_directions, 100.f, 200.f); + weights[4] = RandomMat(outch * num_directions, 100.f, 200.f); + + // initial hidden state + ncnn::Mat hidden = RandomMat(outch, num_directions, -1.f, 1.f); + + std::vector as(2); + as[0] = a; + as[1] = hidden; + + int ret = test_layer("RNN", pd, weights, as, 2); + if (ret != 0) + { + fprintf(stderr, "test_rnn_int8_with_hidden failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +int test_rnn_int8_with_hidden_input(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * num_directions); + pd.set(2, direction); + pd.set(8, 2); // int8_scale_term + + std::vector weights(5); + weights[0] = RandomS8Mat(outch * input_size * num_directions); + weights[1] = RandomMat(outch * num_directions); + weights[2] = RandomS8Mat(outch * outch * num_directions); + weights[3] = RandomMat(outch * num_directions, 100.f, 200.f); + weights[4] = RandomMat(outch * num_directions, 100.f, 200.f); + + // initial hidden state + ncnn::Mat hidden = RandomMat(outch, num_directions, -1.f, 1.f); + + std::vector as(2); + as[0] = a; + as[1] = hidden; + + int ret = test_layer("RNN", pd, weights, as, 1); + if (ret != 0) + { + fprintf(stderr, "test_rnn_int8_with_hidden_input failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +int test_rnn_int8_with_hidden_output(const ncnn::Mat& a, int outch, int direction) +{ + int input_size = a.w; + int num_directions = direction == 2 ? 
2 : 1; + + ncnn::ParamDict pd; + pd.set(0, outch); + pd.set(1, outch * input_size * num_directions); + pd.set(2, direction); + pd.set(8, 2); // int8_scale_term + + std::vector weights(5); + weights[0] = RandomS8Mat(outch * input_size * num_directions); + weights[1] = RandomMat(outch * num_directions); + weights[2] = RandomS8Mat(outch * outch * num_directions); + weights[3] = RandomMat(outch * num_directions, 100.f, 200.f); + weights[4] = RandomMat(outch * num_directions, 100.f, 200.f); + + std::vector as(1); + as[0] = a; + + int ret = test_layer("RNN", pd, weights, as, 2); + if (ret != 0) + { + fprintf(stderr, "test_rnn_int8_with_hidden_output failed a.dims=%d a=(%d %d %d) outch=%d direction=%d\n", a.dims, a.w, a.h, a.c, outch, direction); + } + + return ret; +} + +static int test_rnn_4() +{ + return 0 + || test_rnn_int8(RandomMat(4, 1), 2, 2) + || test_rnn_int8(RandomMat(8, 2), 2, 2) + || test_rnn_int8(RandomMat(16, 8), 7, 2) + || test_rnn_int8(RandomMat(17, 8), 8, 2) + || test_rnn_int8(RandomMat(19, 15), 8, 2) + || test_rnn_int8(RandomMat(5, 16), 16, 2) + || test_rnn_int8(RandomMat(3, 16), 8, 2) + || test_rnn_int8(RandomMat(8, 16), 16, 2) + || test_rnn_int8(RandomMat(31, 3), 31, 2) + || test_rnn_int8(RandomMat(2, 5), 17, 2); +} + +static int test_rnn_5() +{ + return 0 + || test_rnn_int8_with_hidden(RandomMat(4, 4), 1, 2) + || test_rnn_int8_with_hidden(RandomMat(8, 2), 2, 2) + || test_rnn_int8_with_hidden(RandomMat(16, 8), 7, 2) + || test_rnn_int8_with_hidden(RandomMat(17, 8), 8, 2) + || test_rnn_int8_with_hidden(RandomMat(19, 15), 8, 2) + || test_rnn_int8_with_hidden(RandomMat(5, 16), 16, 2) + || test_rnn_int8_with_hidden(RandomMat(3, 16), 8, 2) + || test_rnn_int8_with_hidden(RandomMat(2, 5), 79, 2) + || test_rnn_int8_with_hidden(RandomMat(4, 4), 1, 1) + || test_rnn_int8_with_hidden(RandomMat(8, 2), 2, 1) + || test_rnn_int8_with_hidden(RandomMat(16, 8), 7, 1) + || test_rnn_int8_with_hidden(RandomMat(17, 8), 8, 1) + || test_rnn_int8_with_hidden(RandomMat(19, 15), 8, 1) + || test_rnn_int8_with_hidden(RandomMat(5, 16), 16, 1) + || test_rnn_int8_with_hidden(RandomMat(3, 16), 8, 1) + || test_rnn_int8_with_hidden(RandomMat(2, 5), 79, 1) + || test_rnn_int8_with_hidden(RandomMat(4, 2), 1, 0) + || test_rnn_int8_with_hidden(RandomMat(8, 2), 2, 0) + || test_rnn_int8_with_hidden(RandomMat(16, 8), 7, 0) + || test_rnn_int8_with_hidden(RandomMat(17, 8), 8, 0) + || test_rnn_int8_with_hidden(RandomMat(19, 15), 8, 0) + || test_rnn_int8_with_hidden(RandomMat(5, 16), 16, 0) + || test_rnn_int8_with_hidden(RandomMat(3, 16), 8, 0) + || test_rnn_int8_with_hidden(RandomMat(2, 5), 17, 0) + + || test_rnn_int8_with_hidden_input(RandomMat(4, 4), 1, 2) + || test_rnn_int8_with_hidden_input(RandomMat(8, 2), 2, 2) + || test_rnn_int8_with_hidden_input(RandomMat(16, 8), 7, 2) + || test_rnn_int8_with_hidden_input(RandomMat(17, 8), 8, 2) + || test_rnn_int8_with_hidden_input(RandomMat(19, 15), 8, 2) + || test_rnn_int8_with_hidden_input(RandomMat(5, 16), 16, 2) + || test_rnn_int8_with_hidden_input(RandomMat(3, 16), 8, 2) + || test_rnn_int8_with_hidden_input(RandomMat(2, 5), 79, 2) + || test_rnn_int8_with_hidden_input(RandomMat(4, 4), 1, 1) + || test_rnn_int8_with_hidden_input(RandomMat(8, 2), 2, 1) + || test_rnn_int8_with_hidden_input(RandomMat(16, 8), 7, 1) + || test_rnn_int8_with_hidden_input(RandomMat(17, 8), 8, 1) + || test_rnn_int8_with_hidden_input(RandomMat(19, 15), 8, 1) + || test_rnn_int8_with_hidden_input(RandomMat(5, 16), 16, 1) + || test_rnn_int8_with_hidden_input(RandomMat(3, 16), 8, 1) + || 
test_rnn_int8_with_hidden_input(RandomMat(2, 5), 79, 1) + || test_rnn_int8_with_hidden_input(RandomMat(4, 2), 1, 0) + || test_rnn_int8_with_hidden_input(RandomMat(8, 2), 2, 0) + || test_rnn_int8_with_hidden_input(RandomMat(16, 8), 7, 0) + || test_rnn_int8_with_hidden_input(RandomMat(17, 8), 8, 0) + || test_rnn_int8_with_hidden_input(RandomMat(19, 15), 8, 0) + || test_rnn_int8_with_hidden_input(RandomMat(5, 16), 16, 0) + || test_rnn_int8_with_hidden_input(RandomMat(3, 16), 8, 0) + || test_rnn_int8_with_hidden_input(RandomMat(2, 5), 17, 0) + + || test_rnn_int8_with_hidden_output(RandomMat(4, 4), 1, 2) + || test_rnn_int8_with_hidden_output(RandomMat(8, 2), 2, 2) + || test_rnn_int8_with_hidden_output(RandomMat(16, 8), 7, 2) + || test_rnn_int8_with_hidden_output(RandomMat(17, 8), 8, 2) + || test_rnn_int8_with_hidden_output(RandomMat(19, 15), 8, 2) + || test_rnn_int8_with_hidden_output(RandomMat(5, 16), 16, 2) + || test_rnn_int8_with_hidden_output(RandomMat(3, 16), 8, 2) + || test_rnn_int8_with_hidden_output(RandomMat(2, 5), 79, 2) + || test_rnn_int8_with_hidden_output(RandomMat(4, 4), 1, 1) + || test_rnn_int8_with_hidden_output(RandomMat(8, 2), 2, 1) + || test_rnn_int8_with_hidden_output(RandomMat(16, 8), 7, 1) + || test_rnn_int8_with_hidden_output(RandomMat(17, 8), 8, 1) + || test_rnn_int8_with_hidden_output(RandomMat(19, 15), 8, 1) + || test_rnn_int8_with_hidden_output(RandomMat(5, 16), 16, 1) + || test_rnn_int8_with_hidden_output(RandomMat(3, 16), 8, 1) + || test_rnn_int8_with_hidden_output(RandomMat(2, 5), 79, 1) + || test_rnn_int8_with_hidden_output(RandomMat(4, 2), 1, 0) + || test_rnn_int8_with_hidden_output(RandomMat(8, 2), 2, 0) + || test_rnn_int8_with_hidden_output(RandomMat(16, 8), 7, 0) + || test_rnn_int8_with_hidden_output(RandomMat(17, 8), 8, 0) + || test_rnn_int8_with_hidden_output(RandomMat(19, 15), 8, 0) + || test_rnn_int8_with_hidden_output(RandomMat(5, 16), 16, 0) + || test_rnn_int8_with_hidden_output(RandomMat(3, 16), 8, 0) + || test_rnn_int8_with_hidden_output(RandomMat(2, 5), 17, 0); +} + +static int test_rnn_6() +{ + return 0 + || test_rnn_int8(RandomMat(4, 1), 1, 0) + || test_rnn_int8(RandomMat(8, 2), 2, 0) + || test_rnn_int8(RandomMat(16, 8), 7, 0) + || test_rnn_int8(RandomMat(17, 8), 8, 0) + || test_rnn_int8(RandomMat(19, 15), 8, 0) + || test_rnn_int8(RandomMat(5, 16), 16, 0) + || test_rnn_int8(RandomMat(3, 16), 8, 0) + || test_rnn_int8(RandomMat(8, 16), 16, 0) + || test_rnn_int8(RandomMat(2, 5), 17, 0); +} + +static int test_rnn_7() +{ + return 0 + || test_rnn_int8(RandomMat(4, 1), 1, 1) + || test_rnn_int8(RandomMat(8, 2), 2, 1) + || test_rnn_int8(RandomMat(16, 8), 7, 1) + || test_rnn_int8(RandomMat(17, 8), 8, 1) + || test_rnn_int8(RandomMat(19, 15), 8, 1) + || test_rnn_int8(RandomMat(5, 16), 16, 1) + || test_rnn_int8(RandomMat(3, 16), 8, 1) + || test_rnn_int8(RandomMat(8, 16), 16, 1) + || test_rnn_int8(RandomMat(2, 5), 17, 1); +} +#endif + int main() { SRAND(7767517); - return test_rnn_0() || test_rnn_1() || test_rnn_2() || test_rnn_3(); + +#if NCNN_INT8 + return 0 + || test_rnn_0() + || test_rnn_1() + || test_rnn_2() + || test_rnn_3() + || test_rnn_4() + || test_rnn_5() + || test_rnn_6() + || test_rnn_7(); +#else + return 0 + || test_rnn_0() + || test_rnn_1() + || test_rnn_2() + || test_rnn_3(); +#endif } diff --git a/tests/testutil.cpp b/tests/testutil.cpp index f0bf3c51a20..2e76f6f3901 100644 --- a/tests/testutil.cpp +++ b/tests/testutil.cpp @@ -1380,6 +1380,12 @@ int test_layer_opt(const char* layer_type, const ncnn::ParamDict& pd, const std: 
weights_fp16.resize(weights.size()); for (size_t j = 0; j < weights.size(); j++) { + if (weights[j].elembits() != 32) + { + weights_fp16[j] = weights[j]; + continue; + } + ncnn::Mat tmp; ncnn::cast_float32_to_bfloat16(weights[j], tmp, opt); ncnn::cast_bfloat16_to_float32(tmp, weights_fp16[j], opt); @@ -1391,6 +1397,12 @@ int test_layer_opt(const char* layer_type, const ncnn::ParamDict& pd, const std: weights_fp16.resize(weights.size()); for (size_t j = 0; j < weights.size(); j++) { + if (weights[j].elembits() != 32) + { + weights_fp16[j] = weights[j]; + continue; + } + ncnn::Mat tmp; ncnn::cast_float32_to_float16(weights[j], tmp, opt); ncnn::cast_float16_to_float32(tmp, weights_fp16[j], opt); @@ -1447,6 +1459,12 @@ int test_layer_opt(const char* layer_type, const ncnn::ParamDict& pd, const std: weights_fp16.resize(weights.size()); for (size_t j = 0; j < weights.size(); j++) { + if (weights[j].elembits() != 32) + { + weights_fp16[j] = weights[j]; + continue; + } + ncnn::Mat tmp; ncnn::cast_float32_to_bfloat16(weights[j], tmp, opt); ncnn::cast_bfloat16_to_float32(tmp, weights_fp16[j], opt); @@ -1458,6 +1476,12 @@ int test_layer_opt(const char* layer_type, const ncnn::ParamDict& pd, const std: weights_fp16.resize(weights.size()); for (size_t j = 0; j < weights.size(); j++) { + if (weights[j].elembits() != 32) + { + weights_fp16[j] = weights[j]; + continue; + } + ncnn::Mat tmp; ncnn::cast_float32_to_float16(weights[j], tmp, opt); ncnn::cast_float16_to_float32(tmp, weights_fp16[j], opt); diff --git a/tools/modelwriter.h b/tools/modelwriter.h index 8707fe9eb81..88ccb948a9c 100644 --- a/tools/modelwriter.h +++ b/tools/modelwriter.h @@ -1816,10 +1816,20 @@ int ModelWriter::save(const char* parampath, const char* binpath) fprintf_param_value(" 0=%d", num_output) fprintf_param_value(" 1=%d", weight_data_size) fprintf_param_value(" 2=%d", direction) + fprintf_param_value(" 8=%d", int8_scale_term) fwrite_weight_tag_data(op->weight_xc_data, bp); fwrite_weight_tag_data(op->bias_c_data, bp); fwrite_weight_tag_data(op->weight_hc_data, bp); + +#if NCNN_INT8 + // write int8_scale data + if (op->int8_scale_term) + { + fwrite_weight_data(op->weight_xc_data_int8_scales, bp, 90, 100); + fwrite_weight_data(op->weight_hc_data_int8_scales, bp, 90, 100); + } +#endif // NCNN_INT8 } else if (layer->type == "HardSigmoid") { @@ -1948,6 +1958,7 @@ int ModelWriter::save(const char* parampath, const char* binpath) fprintf_param_value(" 1=%d", weight_data_size) fprintf_param_value(" 2=%d", direction) fprintf_param_value(" 3=%d", hidden_size) + fprintf_param_value(" 8=%d", int8_scale_term) fwrite_weight_tag_data(op->weight_xc_data, bp); fwrite_weight_tag_data(op->bias_c_data, bp); @@ -1957,6 +1968,15 @@ int ModelWriter::save(const char* parampath, const char* binpath) { fwrite_weight_tag_data(op->weight_hr_data, bp); } + +#if NCNN_INT8 + // write int8_scale data + if (op->int8_scale_term) + { + fwrite_weight_data(op->weight_xc_data_int8_scales, bp, 90, 100); + fwrite_weight_data(op->weight_hc_data_int8_scales, bp, 90, 100); + } +#endif // NCNN_INT8 } else if (layer->type == "MatMul") { @@ -2289,10 +2309,20 @@ int ModelWriter::save(const char* parampath, const char* binpath) fprintf_param_value(" 0=%d", num_output) fprintf_param_value(" 1=%d", weight_data_size) fprintf_param_value(" 2=%d", direction) + fprintf_param_value(" 8=%d", int8_scale_term) fwrite_weight_tag_data(op->weight_xc_data, bp); fwrite_weight_tag_data(op->bias_c_data, bp); fwrite_weight_tag_data(op->weight_hc_data, bp); + +#if NCNN_INT8 + // write int8_scale 
data + if (op->int8_scale_term) + { + fwrite_weight_data(op->weight_xc_data_int8_scales, bp, 90, 100); + fwrite_weight_data(op->weight_hc_data_int8_scales, bp, 90, 100); + } +#endif // NCNN_INT8 } else if (layer->type == "ROIAlign") { diff --git a/tools/quantize/ncnn2int8.cpp b/tools/quantize/ncnn2int8.cpp index f712306b022..4d19ceb6f16 100644 --- a/tools/quantize/ncnn2int8.cpp +++ b/tools/quantize/ncnn2int8.cpp @@ -129,6 +129,10 @@ class NetQuantize : public ModelWriter int quantize_convolutiondepthwise(); int quantize_innerproduct(); + int quantize_rnn(); + int quantize_lstm(); + int quantize_gru(); + int fuse_requantize(); }; @@ -312,6 +316,252 @@ int NetQuantize::quantize_innerproduct() return 0; } +int NetQuantize::quantize_rnn() +{ + for (size_t i = 0; i < layers.size(); i++) + { + if (layers[i]->type != "RNN") + continue; + + // RNN - quantize weight from fp32 to int8 + ncnn::RNN* rnn = (ncnn::RNN*)layers[i]; + + fprintf(stderr, "quantize_rnn %s\n", rnn->name.c_str()); + + // TODO move to ncnn2table + const int num_directions = rnn->direction == 2 ? 2 : 1; + const int size = rnn->weight_data_size / num_directions / rnn->num_output; + + ncnn::Mat weight_xc_data_int8_scales(rnn->num_output * num_directions); + ncnn::Mat weight_hc_data_int8_scales(rnn->num_output * num_directions); + + for (int d = 0; d < num_directions; d++) + { + for (int q = 0; q < rnn->num_output; q++) + { + { + const float* weight_xc_ptr = rnn->weight_xc_data.channel(d).row(q); + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(weight_xc_ptr[i])); + } + weight_xc_data_int8_scales[d * rnn->num_output + q] = 127 / absmax; + } + + { + const float* weight_hc_ptr = rnn->weight_hc_data.channel(d).row(q); + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(weight_hc_ptr[i])); + } + weight_hc_data_int8_scales[d * rnn->num_output + q] = 127 / absmax; + } + } + } + + { + ncnn::Mat weight_xc_data_r2 = rnn->weight_xc_data.reshape(size, rnn->num_output * num_directions); + + ncnn::Mat weight_xc_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = rnn->weight_xc_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(weight_xc_data_r2, weight_xc_data_int8, weight_xc_data_int8_scales, opt_q); + if (weight_xc_data_int8.empty()) + return -100; + + rnn->weight_xc_data = weight_xc_data_int8.reshape(size * rnn->num_output * num_directions); + } + { + ncnn::Mat weight_hc_data_r2 = rnn->weight_hc_data.reshape(rnn->num_output, rnn->num_output * num_directions); + + ncnn::Mat weight_hc_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = rnn->weight_hc_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(weight_hc_data_r2, weight_hc_data_int8, weight_hc_data_int8_scales, opt_q); + if (weight_hc_data_int8.empty()) + return -100; + + rnn->weight_hc_data = weight_hc_data_int8.reshape(rnn->num_output * rnn->num_output * num_directions); + } + + rnn->int8_scale_term = 2; + rnn->weight_xc_data_int8_scales = weight_xc_data_int8_scales; + rnn->weight_hc_data_int8_scales = weight_hc_data_int8_scales; + } + + return 0; +} + +int NetQuantize::quantize_lstm() +{ + for (size_t i = 0; i < layers.size(); i++) + { + if (layers[i]->type != "LSTM") + continue; + + // LSTM - quantize weight from fp32 to int8 + ncnn::LSTM* lstm = (ncnn::LSTM*)layers[i]; + + fprintf(stderr, "quantize_lstm %s\n", lstm->name.c_str()); + + // TODO move to ncnn2table + const int num_directions = lstm->direction == 2 
? 2 : 1; + const int size = lstm->weight_data_size / num_directions / lstm->hidden_size / 4; + + ncnn::Mat weight_xc_data_int8_scales(lstm->hidden_size * 4 * num_directions); + ncnn::Mat weight_hc_data_int8_scales(lstm->hidden_size * 4 * num_directions); + + for (int d = 0; d < num_directions; d++) + { + for (int q = 0; q < lstm->hidden_size * 4; q++) + { + { + const float* weight_xc_ptr = lstm->weight_xc_data.channel(d).row(q); + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(weight_xc_ptr[i])); + } + weight_xc_data_int8_scales[d * lstm->hidden_size * 4 + q] = 127 / absmax; + } + + { + const float* weight_hc_ptr = lstm->weight_hc_data.channel(d).row(q); + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(weight_hc_ptr[i])); + } + weight_hc_data_int8_scales[d * lstm->hidden_size * 4 + q] = 127 / absmax; + } + } + } + + { + ncnn::Mat weight_xc_data_r2 = lstm->weight_xc_data.reshape(size, lstm->hidden_size * 4 * num_directions); + + ncnn::Mat weight_xc_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = lstm->weight_xc_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(weight_xc_data_r2, weight_xc_data_int8, weight_xc_data_int8_scales, opt_q); + if (weight_xc_data_int8.empty()) + return -100; + + lstm->weight_xc_data = weight_xc_data_int8.reshape(size * lstm->hidden_size * 4 * num_directions); + } + { + ncnn::Mat weight_hc_data_r2 = lstm->weight_hc_data.reshape(lstm->num_output, lstm->hidden_size * 4 * num_directions); + + ncnn::Mat weight_hc_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = lstm->weight_hc_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(weight_hc_data_r2, weight_hc_data_int8, weight_hc_data_int8_scales, opt_q); + if (weight_hc_data_int8.empty()) + return -100; + + lstm->weight_hc_data = weight_hc_data_int8.reshape(lstm->num_output * lstm->hidden_size * 4 * num_directions); + } + + lstm->int8_scale_term = 2; + lstm->weight_xc_data_int8_scales = weight_xc_data_int8_scales; + lstm->weight_hc_data_int8_scales = weight_hc_data_int8_scales; + } + + return 0; +} + +int NetQuantize::quantize_gru() +{ + for (size_t i = 0; i < layers.size(); i++) + { + if (layers[i]->type != "GRU") + continue; + + // GRU - quantize weight from fp32 to int8 + ncnn::GRU* gru = (ncnn::GRU*)layers[i]; + + fprintf(stderr, "quantize_gru %s\n", gru->name.c_str()); + + // TODO move to ncnn2table + const int num_directions = gru->direction == 2 ? 
2 : 1; + const int size = gru->weight_data_size / num_directions / gru->num_output / 3; + + ncnn::Mat weight_xc_data_int8_scales(gru->num_output * 3 * num_directions); + ncnn::Mat weight_hc_data_int8_scales(gru->num_output * 3 * num_directions); + + for (int d = 0; d < num_directions; d++) + { + for (int q = 0; q < gru->num_output * 3; q++) + { + { + const float* weight_xc_ptr = gru->weight_xc_data.channel(d).row(q); + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(weight_xc_ptr[i])); + } + weight_xc_data_int8_scales[d * gru->num_output * 3 + q] = 127 / absmax; + } + + { + const float* weight_hc_ptr = gru->weight_hc_data.channel(d).row(q); + float absmax = 0.f; + for (int i = 0; i < size; i++) + { + absmax = std::max(absmax, (float)fabs(weight_hc_ptr[i])); + } + weight_hc_data_int8_scales[d * gru->num_output * 3 + q] = 127 / absmax; + } + } + } + + { + ncnn::Mat weight_xc_data_r2 = gru->weight_xc_data.reshape(size, gru->num_output * 3 * num_directions); + + ncnn::Mat weight_xc_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = gru->weight_xc_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(weight_xc_data_r2, weight_xc_data_int8, weight_xc_data_int8_scales, opt_q); + if (weight_xc_data_int8.empty()) + return -100; + + gru->weight_xc_data = weight_xc_data_int8.reshape(size * gru->num_output * 3 * num_directions); + } + { + ncnn::Mat weight_hc_data_r2 = gru->weight_hc_data.reshape(gru->num_output, gru->num_output * 3 * num_directions); + + ncnn::Mat weight_hc_data_int8; + + ncnn::Option opt_q = opt; + opt_q.blob_allocator = gru->weight_hc_data.allocator; + opt_q.use_packing_layout = false; + ncnn::quantize_to_int8(weight_hc_data_r2, weight_hc_data_int8, weight_hc_data_int8_scales, opt_q); + if (weight_hc_data_int8.empty()) + return -100; + + gru->weight_hc_data = weight_hc_data_int8.reshape(gru->num_output * gru->num_output * 3 * num_directions); + } + + gru->int8_scale_term = 2; + gru->weight_xc_data_int8_scales = weight_xc_data_int8_scales; + gru->weight_hc_data_int8_scales = weight_hc_data_int8_scales; + } + + return 0; +} + int NetQuantize::fuse_requantize() { const size_t layer_count = layers.size(); @@ -517,7 +767,7 @@ int NetQuantize::fuse_requantize() int main(int argc, char** argv) { - if (argc != 6) + if (argc != 5 && argc != 6) { fprintf(stderr, "usage: %s [inparam] [inbin] [outparam] [outbin] [calibration table]\n", argv[0]); return -1; @@ -527,7 +777,7 @@ int main(int argc, char** argv) const char* inbin = argv[2]; const char* outparam = argv[3]; const char* outbin = argv[4]; - const char* int8scale_table_path = argv[5]; + const char* int8scale_table_path = argc == 6 ? argv[5] : NULL; NetQuantize quantizer; @@ -556,6 +806,10 @@ int main(int argc, char** argv) quantizer.quantize_convolutiondepthwise(); quantizer.quantize_innerproduct(); + quantizer.quantize_rnn(); + quantizer.quantize_lstm(); + quantizer.quantize_gru(); + quantizer.fuse_requantize(); quantizer.save(outparam, outbin);
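Note on the quantization rule used above: quantize_rnn, quantize_lstm and quantize_gru all follow the same pattern. For each output row of weight_xc_data and weight_hc_data they compute one scale as 127 / absmax of that row, quantize the row to int8 via ncnn::quantize_to_int8, store the scales in weight_xc_data_int8_scales / weight_hc_data_int8_scales, and set int8_scale_term = 2 so the runtime knows the weights are already int8. Below is a minimal standalone sketch of that per-row rule in plain C++; the helper names (quantize_weight_rows, float2int8) are illustrative only and are not part of the ncnn API.

// Sketch of per-output-channel weight quantization, mirroring the
// 127/absmax rule in quantize_rnn/quantize_lstm/quantize_gru.
// Illustrative code, not ncnn API.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

static int8_t float2int8(float v)
{
    // round to nearest and clamp to the symmetric int8 range [-127, 127]
    int q = static_cast<int>(std::round(v));
    return static_cast<int8_t>(std::min(std::max(q, -127), 127));
}

// weights: rows x cols matrix, row-major; one scale per row (output channel)
static void quantize_weight_rows(const std::vector<float>& weights, int rows, int cols,
                                 std::vector<int8_t>& weights_int8, std::vector<float>& scales)
{
    weights_int8.resize(static_cast<size_t>(rows) * cols);
    scales.resize(rows);

    for (int q = 0; q < rows; q++)
    {
        const float* ptr = &weights[static_cast<size_t>(q) * cols];

        // per-row absolute maximum
        float absmax = 0.f;
        for (int i = 0; i < cols; i++)
            absmax = std::max(absmax, std::fabs(ptr[i]));

        // same 127/absmax rule as the ncnn2int8 passes above
        const float scale = 127.f / absmax;
        scales[q] = scale;

        for (int i = 0; i < cols; i++)
            weights_int8[static_cast<size_t>(q) * cols + i] = float2int8(ptr[i] * scale);
    }
}

Keeping one scale per output row (rather than one scale for the whole matrix) bounds the quantization error per output channel, which is why the scale Mats are sized num_output * num_directions for RNN, hidden_size * 4 * num_directions for LSTM and num_output * 3 * num_directions for GRU.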