From 9d14ae750d5a41d8fd5c0ae824f502bdb67331c8 Mon Sep 17 00:00:00 2001
From: Wei-Chen Wang
Date: Fri, 19 Apr 2024 15:52:48 -0400
Subject: [PATCH] Fix bugs for ARM platforms (#103)

---
 kernels/neon/matmul_neon_fp32.cc      | 80 +++++++++++++++------------
 kernels/neon/matmul_neon_int8_int4.cc |  9 ++-
 llm/src/ops/BMM_F32T.cc               |  4 +-
 llm/src/ops/linear.cc                 |  2 +-
 4 files changed, 57 insertions(+), 38 deletions(-)

diff --git a/kernels/neon/matmul_neon_fp32.cc b/kernels/neon/matmul_neon_fp32.cc
index cd8c098..e056b92 100644
--- a/kernels/neon/matmul_neon_fp32.cc
+++ b/kernels/neon/matmul_neon_fp32.cc
@@ -3,10 +3,13 @@
 #include
 #include
 #include
-#include
 // #include
 #include
 
+#ifdef USE_ACCELERATE
+#include
+#endif
+
 #include "common.h"
 #include "../matmul.h"
 #include "pthread_pool.h"
@@ -38,6 +41,29 @@ void fp32_ref_matmul(const struct matmul_params *params) {
     }
 }
 
+void fp32_ref_matmul_bias(const struct matmul_params *params) {
+    const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
+    float *bias = params->bias.data_ptr;
+    float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr;
+
+    assert(A->column == B->row);
+    assert(C->row == A->row);
+    assert(C->column == B->column);
+    int m = A->row, n = B->column, k = A->column;
+
+    for (int i = 0; i < m; i++) {
+        for (int j = 0; j < n; j++) {
+            float acc = 0;
+            for (int kk = 0; kk < k; kk++) {
+                acc += data_A[i * k + kk] * data_B[j * k + kk];
+            }
+            acc = acc + bias[j];
+            data_C[i * n + j] = acc;
+        }
+    }
+}
+
+#ifdef USE_ACCELERATE
 inline void fp32_matmul_transposed_cblas_gemm(const struct matmul_params *params) {
     const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
     float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr;
@@ -55,11 +81,6 @@ inline void fp32_matmul_transposed_cblas_gemm(const struct matmul_params *params
                 0.0f, data_C, n);
 }
 
-void MatmulOperator::mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params) {
-    // fp32_ref_matmul(params);
-    fp32_matmul_transposed_cblas_gemm(params);
-}
-
 inline void fp32_matmul_untransposed_cblas_gemm(const struct matmul_params *params) {
     const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
     float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr;
@@ -76,32 +97,6 @@ inline void fp32_matmul_untransposed_cblas_gemm(const struct matmul_params *para
                 0.0f, data_C, n);
 }
 
-void MatmulOperator::mat_mul_accelerator_untransposed_fastover_column(const struct matmul_params *params) {
-    fp32_matmul_untransposed_cblas_gemm(params);
-}
-
-void fp32_ref_matmul_bias(const struct matmul_params *params) {
-    const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
-    float *bias = params->bias.data_ptr;
-    float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr;
-
-    assert(A->column == B->row);
-    assert(C->row == A->row);
-    assert(C->column == B->column);
-    int m = A->row, n = B->column, k = A->column;
-
-    for (int i = 0; i < m; i++) {
-        for (int j = 0; j < n; j++) {
-            float acc = 0;
-            for (int kk = 0; kk < k; kk++) {
-                acc += data_A[i * k + kk] * data_B[j * k + kk];
-            }
-            acc = acc + bias[j];
-            data_C[i * n + j] = acc;
-        }
-    }
-}
-
 void fp32_matmul_bias_cblas_gemm(const struct matmul_params *params) {
     // struct fp32_thread_args* mat_args = (struct fp32_thread_args*)args;
     const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
@@ -123,6 +118,21 @@ void fp32_matmul_bias_cblas_gemm(const struct matmul_params *params) {
         vDSP_vadd(bias, 1, data_C + i * n, 1, data_C + i * n, 1, n);
     }
 }
+#endif
+
+void MatmulOperator::mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params) {
+#ifdef USE_ACCELERATE
+    fp32_matmul_transposed_cblas_gemm(params);
+#else
+    fp32_ref_matmul(params);
+#endif
+}
+
+void MatmulOperator::mat_mul_accelerator_untransposed_fastover_column(const struct matmul_params *params) {
+#ifdef USE_ACCELERATE
+    fp32_matmul_untransposed_cblas_gemm(params);
+#endif
+}
 
 inline static void* fp32_matmul_bias_optimized_gemm(void* args) {
     struct fp32_thread_args* mat_args = (struct fp32_thread_args*)args;
@@ -251,9 +261,11 @@ inline static void* fp32_matmul_bias_optimized_gemm(void* args) {
 }
 
 void MatmulOperator::mat_mul_accelerator_transposed_fastover_column_bias(const struct matmul_params *params) {
-    // fp32_ref_matmul_bias(params);
-
+#ifdef USE_ACCELERATE
     fp32_matmul_bias_cblas_gemm(params);
+#else
+    fp32_ref_matmul_bias(params);
+#endif
 
     // int i, j, k;
     // const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
diff --git a/kernels/neon/matmul_neon_int8_int4.cc b/kernels/neon/matmul_neon_int8_int4.cc
index 2d7c134..8b5bdf4 100644
--- a/kernels/neon/matmul_neon_int8_int4.cc
+++ b/kernels/neon/matmul_neon_int8_int4.cc
@@ -4,9 +4,12 @@
 #include
 #include
 #include
-#include
 #include
 
+#ifdef USE_ACCELERATE
+#include
+#endif
+
 #include "../matmul.h"
 #include "common.h"
 #include "pthread_pool.h"
@@ -1265,6 +1268,7 @@ static void* matmul_int8_int4_no_offset_over_column_packed(void* args) {
     return NULL;
 }
 
+#ifdef USE_ACCELERATE
 inline static void* fp32_matmul_transposed_cblas_gemm(void* args) {
     struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args;
     const struct matmul_params* params = mat_args->params;
@@ -1286,6 +1290,7 @@ inline static void* fp32_matmul_transposed_cblas_gemm(void* args) {
 
     return NULL;
 }
+#endif
 
 namespace matmul {
 void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_params* params) {
@@ -1433,6 +1438,7 @@ void MatmulOperator::gemm_accelerator_int8_int4_fast_no_offset_v2(struct matmul_
     pool_wait(pool);
 };
 
+#ifdef USE_ACCELERATE
 void MatmulOperator::cblas_gemm_accelerator_no_offset(struct matmul_params* params) {
     int i, j, k;
     const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
@@ -1470,5 +1476,6 @@ void MatmulOperator::cblas_gemm_accelerator_no_offset(struct matmul_params* para
     // Join threads
     pool_wait(pool);
 };
+#endif
 
 } // namespace matmul
diff --git a/llm/src/ops/BMM_F32T.cc b/llm/src/ops/BMM_F32T.cc
index 81cdee4..bff8b7d 100644
--- a/llm/src/ops/BMM_F32T.cc
+++ b/llm/src/ops/BMM_F32T.cc
@@ -38,7 +38,7 @@ void BMM_F32T::forward(const Matrix3D &a, const Matrix3D &weight,
     // op.mat_mul_transposed_fastover_column((const struct matmul_params
     // *)&params);
     // else
-#ifdef QM_ARM
+#ifdef USE_ACCELERATE
     op.mat_mul_accelerator_transposed_fastover_column(&params);
 #else
     op.mat_mul_transposed(&params);  // TODO: optimize this
@@ -87,7 +87,7 @@ void BMM_F32T::forward_weight_untransposed(const Matrix3D &a, const Matri
     for (int i = 0; i < m * n * a.m_dim_x; i++) {
         params.C.data_ptr[i] = 0;
     }
-#ifdef QM_ARM
+#ifdef USE_ACCELERATE
     for (int bz = 0; bz < a.m_dim_x; bz++) {
         op.mat_mul_accelerator_untransposed_fastover_column(&params);
         params.A.data_ptr += m * k;
diff --git a/llm/src/ops/linear.cc b/llm/src/ops/linear.cc
index 863b997..3be0053 100644
--- a/llm/src/ops/linear.cc
+++ b/llm/src/ops/linear.cc
@@ -26,7 +26,7 @@ void linear(Matrix3D &a, Matrix3D &b, Matrix3D &c) {
     }
 }
 
-#ifdef QM_ARM
+#ifdef USE_ACCELERATE
 #define MAX_WEIGHT_BUFFER 32000 * 4096
 static float *w_fp32;
 void Linear_FP_int4::initialize_weight_memory() {