diff --git a/README.md b/README.md index b4829f0b..331bd8ed 100644 --- a/README.md +++ b/README.md @@ -151,7 +151,7 @@ TinyChatEngine offers versatile capabilities suitable for various applications. - Start the speech-to-speech chat locally. ```bash - ./chat -v # chat.exe -v on Windows + ./voicechat # chat.exe -v on Windows ``` - If you encounter any issues or errors during setup, please explore [here](llm/application/README.md) to follow the step-by-step guide to debug. @@ -159,7 +159,7 @@ TinyChatEngine offers versatile capabilities suitable for various applications. ## Deploy vision language model (VLM) chatbot with TinyChatEngine -TinyChatEngine supports not only LLM but also VLM. We introduce a sophisticated text/voice chatbot for VLM. Here, we provide very easy-to-follow instructions to deploy vision language model chatbot (VILA-7B) with TinyChatEngine. +TinyChatEngine supports not only LLM but also VLM. We introduce a sophisticated text/voice chatbot for VLM. Here, we provide easy-to-follow instructions to deploy vision language model chatbot (VILA-7B) with TinyChatEngine. We recommend using M1/M2 MacBooks for this VLM feature. - Follow the instructions above to setup the basic environment, i.e., [Prerequisites](#prerequisites) and [Step-by-step to Deploy LLaMA2-7B-chat with TinyChatEngine](#step-by-step-to-deploy-llama2-7b-chat-with-tinychatengine). @@ -169,6 +169,10 @@ TinyChatEngine supports not only LLM but also VLM. We introduce a sophisticated - (For other OS) Please refer to [here](https://github.com/AnonymouX47/termvisage?tab=readme-ov-file#requirements) to get the appropriate terminal ready. - (Optional) To enable the speech-to-speech chatbot for VLM, please follow the [instruction above](#deploy-speech-to-speech-chatbot-with-tinychatengine-demo) to run the shell script to set up the environment. + ```bash + cd llm + ./voicechat_setup.sh + ``` - Download the quantized VILA-7B model from our model zoo. @@ -184,12 +188,12 @@ TinyChatEngine supports not only LLM but also VLM. We introduce a sophisticated - (For MacOS) Start the chatbot locally. Please use an appropriate terminal (e.g., iTerm2). - Image/Text to text ```bash - ./scripts/vila.sh ../assets/figures/vlm_demo/pedestrian.png + ./vila ../assets/figures/vlm_demo/pedestrian.png ``` - Image/Speech to speech ```bash - ./scripts/voice_vila.sh ../assets/figures/vlm_demo/pedestrian.png + ./voice_vila ../assets/figures/vlm_demo/pedestrian.png ``` - There are several images under the path `../assets/figures/vlm_demo`. Feel free to try different images with VILA on your device! diff --git a/kernels/matmul.h b/kernels/matmul.h index 563adaf8..8c186ad4 100644 --- a/kernels/matmul.h +++ b/kernels/matmul.h @@ -99,15 +99,16 @@ struct thread_args { int start_i, end_i, blk_size; }; - #define MAX(A, B) ((A) > (B) ? (A) : (B)) #define MIN(A, B) ((A) < (B) ? (A) : (B)) + namespace matmul { class MatmulOperator { public: void mat_mul_transposed(const struct matmul_params *params); void mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params); void mat_mul_accelerator_transposed_fastover_column_bias(const struct matmul_params *params); + void mat_mul_accelerator_untransposed_fastover_column(const struct matmul_params *params); // int8 void naive_mat_mul_int8(const struct matmul_params *params); void mat_mul_accelerator_int8_fast_32unroll_over_column(const struct matmul_params *params); @@ -125,6 +126,8 @@ class MatmulOperator { void mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_params *params); void gemv_accelerator_int8_int4_fast_no_offset(struct matmul_params *params); void gemm_accelerator_int8_int4_fast_no_offset(struct matmul_params *params); + void gemm_accelerator_int8_int4_fast_no_offset_v2(struct matmul_params *params); + void cblas_gemm_accelerator_no_offset(struct matmul_params *params); void naive_mat_mul_int4(const struct matmul_params *params); void naive_mat_mul_int4_with_offset(const struct matmul_params *params); // cuda diff --git a/kernels/neon/matmul_neon_fp32.cc b/kernels/neon/matmul_neon_fp32.cc index 2a041603..cd8c098b 100644 --- a/kernels/neon/matmul_neon_fp32.cc +++ b/kernels/neon/matmul_neon_fp32.cc @@ -38,25 +38,46 @@ void fp32_ref_matmul(const struct matmul_params *params) { } } -void fp32_matmul_cblas_gemm(const struct matmul_params *params) { +inline void fp32_matmul_transposed_cblas_gemm(const struct matmul_params *params) { + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr; + float alpha = params->alpha; + + assert(A->column == B->column); + assert(C->row == A->row); + assert(C->column == B->row); + int m = C->row, n = C->column, k = A->column; + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + m, n, k, + alpha, data_A, k, + data_B, k, + 0.0f, data_C, n); +} + +void MatmulOperator::mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params) { + // fp32_ref_matmul(params); + fp32_matmul_transposed_cblas_gemm(params); +} + +inline void fp32_matmul_untransposed_cblas_gemm(const struct matmul_params *params) { const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr; assert(A->column == B->row); assert(C->row == A->row); assert(C->column == B->column); - int m = A->row, n = B->column, k = A->column; + int m = C->row, n = C->column, k = A->column; cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, - m, n, k, - 1.0f, data_A, m, - data_B, k, - 0.0f, data_C, m); + m, n, k, + 1.0f, data_A, k, + data_B, n, + 0.0f, data_C, n); } -void MatmulOperator::mat_mul_accelerator_transposed_fastover_column(const struct matmul_params *params) { - fp32_ref_matmul(params); - // fp32_matmul_cblas_gemm(params); +void MatmulOperator::mat_mul_accelerator_untransposed_fastover_column(const struct matmul_params *params) { + fp32_matmul_untransposed_cblas_gemm(params); } void fp32_ref_matmul_bias(const struct matmul_params *params) { diff --git a/kernels/neon/matmul_neon_int8_int4.cc b/kernels/neon/matmul_neon_int8_int4.cc index ba463b0b..2d7c134e 100644 --- a/kernels/neon/matmul_neon_int8_int4.cc +++ b/kernels/neon/matmul_neon_int8_int4.cc @@ -2,17 +2,17 @@ #include #include #include - #include #include +#include +#include #include "../matmul.h" #include "common.h" - #include "pthread_pool.h" struct a8w4_thread_args { - int start_j, end_j; + int start_i, end_i, start_j, end_j, tile_size; const struct matmul_params* params; }; @@ -76,6 +76,83 @@ void quantize_fp32_to_int8(float* A, int8_t* qA, float* sA, int size, int block_ } } +void dequantize_int4_to_fp32(uint8_t* qW, float* W, float* sW, int size, int block_size) { + assert(size % block_size == 0); + assert(block_size == 32); + int num_block = size / 32; + + const uint8x16_t mask_low4bit = vdupq_n_u8(0xf); + const int8x16_t offsets = vdupq_n_s8(8); + float* w_start_fp32 = &W[0]; + for (int i = 0; i < num_block; i++) { + const unsigned char* w_start = &qW[i * 16]; + float* s_w = &sW[i]; + float s_0 = s_w[0]; + + const uint8x16_t w0 = vld1q_u8(w_start); // 32 4bit weight + + // Quantization Method QM_ARM, convert 64 4-bit to 64 8-bit + // sequential: (0, 1), (2, 3), (4, 5), (6, 7)... : 128 bit + // expected layout of inB: (0, 16), (1, 17), (2, 18), (3, 19)... + // low; (0, 0), (1, 0), (2, 0), (3, 0) ... + // high: (16, 0), (17, 0), (18, 0), (19, 0) ... + int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit)); + int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4)); + + // apply offset + w0_low = vsubq_s8(w0_low, offsets); + w0_high = vsubq_s8(w0_high, offsets); + + // Step 1: Split each int8x16_t vector into two int8x8_t vectors + int8x8_t w0_low_low = vget_low_s8(w0_low); + int8x8_t w0_low_high = vget_high_s8(w0_low); + int8x8_t w0_high_low = vget_low_s8(w0_high); + int8x8_t w0_high_high = vget_high_s8(w0_high); + + // Step 2: Extend each int8x8_t vector to int16x8_t + int16x8_t w0_ll_ext = vmovl_s8(w0_low_low); + int16x8_t w0_lh_ext = vmovl_s8(w0_low_high); + int16x8_t w0_hl_ext = vmovl_s8(w0_high_low); + int16x8_t w0_hh_ext = vmovl_s8(w0_high_high); + + // Step 3: Further extend int16x8_t to int32x4_t and then convert to float32x4_t + float32x4_t w0_ll_f = vcvtq_f32_s32(vmovl_s16(vget_low_s16(w0_ll_ext))); + float32x4_t w0_lh_f = vcvtq_f32_s32(vmovl_s16(vget_high_s16(w0_ll_ext))); + float32x4_t w0_hl_f = vcvtq_f32_s32(vmovl_s16(vget_low_s16(w0_lh_ext))); + float32x4_t w0_hh_f = vcvtq_f32_s32(vmovl_s16(vget_high_s16(w0_lh_ext))); + float32x4_t w1_ll_f = vcvtq_f32_s32(vmovl_s16(vget_low_s16(w0_hl_ext))); + float32x4_t w1_lh_f = vcvtq_f32_s32(vmovl_s16(vget_high_s16(w0_hl_ext))); + float32x4_t w1_hl_f = vcvtq_f32_s32(vmovl_s16(vget_low_s16(w0_hh_ext))); + float32x4_t w1_hh_f = vcvtq_f32_s32(vmovl_s16(vget_high_s16(w0_hh_ext))); + + float32x4_t sumv0_ll = vmulq_n_f32(w0_ll_f, s_0); + float32x4_t sumv0_lh = vmulq_n_f32(w0_lh_f, s_0); + float32x4_t sumv0_hl = vmulq_n_f32(w0_hl_f, s_0); + float32x4_t sumv0_hh = vmulq_n_f32(w0_hh_f, s_0); + float32x4_t sumv1_ll = vmulq_n_f32(w1_ll_f, s_0); + float32x4_t sumv1_lh = vmulq_n_f32(w1_lh_f, s_0); + float32x4_t sumv1_hl = vmulq_n_f32(w1_hl_f, s_0); + float32x4_t sumv1_hh = vmulq_n_f32(w1_hh_f, s_0); + + vst1q_f32(w_start_fp32, sumv0_ll); + w_start_fp32 += 4; + vst1q_f32(w_start_fp32, sumv0_lh); + w_start_fp32 += 4; + vst1q_f32(w_start_fp32, sumv0_hl); + w_start_fp32 += 4; + vst1q_f32(w_start_fp32, sumv0_hh); + w_start_fp32 += 4; + vst1q_f32(w_start_fp32, sumv1_ll); + w_start_fp32 += 4; + vst1q_f32(w_start_fp32, sumv1_lh); + w_start_fp32 += 4; + vst1q_f32(w_start_fp32, sumv1_hl); + w_start_fp32 += 4; + vst1q_f32(w_start_fp32, sumv1_hh); + w_start_fp32 += 4; + } +} + void matmul_int8_int4_no_offset(struct matmul_params* params) { int n = params->C.column, m = params->C.row, k = params->A.column, block_size = params->block_size; assert(params->block_size == 32); @@ -330,6 +407,551 @@ inline static void* gemv_int8_int4_no_offset_over_column_unroll128(void* args) { return NULL; } +inline static void* gemm_int8_int4_no_offset_over_column_unroll128(void* args) { + struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args; + const struct matmul_params* params = mat_args->params; + int n = params->C.column, m = params->C.row, k = params->A.column, block_size = params->block_size; + const int num_block = k / block_size; + int TILE_SIZE = mat_args->tile_size; + + // assert((mat_args->end_i - mat_args->start_i) % TILE_SIZE == 0); + assert(k % TILE_SIZE == 0); + assert(n % TILE_SIZE == 0); + // assert(TILE_SIZE % 4 == 0); + + for (int ti = mat_args->start_i; ti < mat_args->end_i; ti += TILE_SIZE) { + for (int tj = 0; tj < n; tj += TILE_SIZE) { + for (int i = ti; i < ti + TILE_SIZE; i++) { + for (int j = tj; j < tj + TILE_SIZE; j++) { + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + float32x4_t sumv2 = vdupq_n_f32(0.0f); + float32x4_t sumv3 = vdupq_n_f32(0.0f); + const unsigned char* w_start = ¶ms->B.int4_data_ptr[j * k / 2]; + const signed char* a_start = ¶ms->A.int8_data_ptr[i * k]; + float* s_a = ¶ms->A_scales[i * k / 32]; + float* s_w = ¶ms->scales[j * k / 32]; + + const uint8x16_t mask_low4bit = vdupq_n_u8(0xf); + const int8x16_t offsets = vdupq_n_s8(8); + for (int q = 0; q < num_block; q += 4) { + int32x4_t int_sum0 = vdupq_n_s32(0); + int32x4_t int_sum1 = vdupq_n_s32(0); + int32x4_t int_sum2 = vdupq_n_s32(0); + int32x4_t int_sum3 = vdupq_n_s32(0); + float s_0 = *s_a++ * *s_w++; + float s_1 = *s_a++ * *s_w++; + float s_2 = *s_a++ * *s_w++; + float s_3 = *s_a++ * *s_w++; + + const uint8x16_t w0 = vld1q_u8(w_start); // 32 4bit weight + const uint8x16_t w1 = vld1q_u8(w_start + 16); // 32 4bit weight + const uint8x16_t w2 = vld1q_u8(w_start + 32); // 32 4bit weight + const uint8x16_t w3 = vld1q_u8(w_start + 48); // 32 4bit weight + w_start += 64; + + // Quantization Method QM_ARM, convert 64 4-bit to 64 8-bit + // sequential: (0, 1), (2, 3), (4, 5), (6, 7)... : 128 bit + // expected layout of inB: (0, 16), (1, 17), (2, 18), (3, 19)... + // low; (0, 0), (1, 0), (2, 0), (3, 0) ... + // high: (16, 0), (17, 0), (18, 0), (19, 0) ... + int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit)); + int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4)); + int8x16_t w1_low = vreinterpretq_s8_u8(vandq_u8(w1, mask_low4bit)); + int8x16_t w1_high = vreinterpretq_s8_u8(vshrq_n_u8(w1, 4)); + int8x16_t w2_low = vreinterpretq_s8_u8(vandq_u8(w2, mask_low4bit)); + int8x16_t w2_high = vreinterpretq_s8_u8(vshrq_n_u8(w2, 4)); + int8x16_t w3_low = vreinterpretq_s8_u8(vandq_u8(w3, mask_low4bit)); + int8x16_t w3_high = vreinterpretq_s8_u8(vshrq_n_u8(w3, 4)); + + // apply offset + w0_low = vsubq_s8(w0_low, offsets); + w0_high = vsubq_s8(w0_high, offsets); + w1_low = vsubq_s8(w1_low, offsets); + w1_high = vsubq_s8(w1_high, offsets); + w2_low = vsubq_s8(w2_low, offsets); + w2_high = vsubq_s8(w2_high, offsets); + w3_low = vsubq_s8(w3_low, offsets); + w3_high = vsubq_s8(w3_high, offsets); + + // load 64 8-bit activation + const int8x16_t a0 = vld1q_s8(a_start); + const int8x16_t a1 = vld1q_s8(a_start + 16); + const int8x16_t a2 = vld1q_s8(a_start + 32); + const int8x16_t a3 = vld1q_s8(a_start + 48); + const int8x16_t a4 = vld1q_s8(a_start + 64); + const int8x16_t a5 = vld1q_s8(a_start + 80); + const int8x16_t a6 = vld1q_s8(a_start + 96); + const int8x16_t a7 = vld1q_s8(a_start + 112); + a_start += 128; + + // dot product into int32x4_t + int_sum0 = my_vdotq_s32(int_sum0, w0_low, a0); + int_sum0 = my_vdotq_s32(int_sum0, w0_high, a1); + int_sum1 = my_vdotq_s32(int_sum1, w1_low, a2); + int_sum1 = my_vdotq_s32(int_sum1, w1_high, a3); + int_sum2 = my_vdotq_s32(int_sum2, w2_low, a4); + int_sum2 = my_vdotq_s32(int_sum2, w2_high, a5); + int_sum3 = my_vdotq_s32(int_sum3, w3_low, a6); + int_sum3 = my_vdotq_s32(int_sum3, w3_high, a7); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(int_sum1), s_1); + sumv2 = vmlaq_n_f32(sumv2, vcvtq_f32_s32(int_sum2), s_2); + sumv3 = vmlaq_n_f32(sumv3, vcvtq_f32_s32(int_sum3), s_3); + } + if (params->bias.data_ptr) { + params->C.data_ptr[i * n + j] = params->bias.data_ptr[j] + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + + vaddvq_f32(sumv2) + vaddvq_f32(sumv3); + } + else { + params->C.data_ptr[i * n + j] = + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + vaddvq_f32(sumv2) + vaddvq_f32(sumv3); + } + + + ////////////////////////////////////////// + // float32x4_t sumv0 = vdupq_n_f32(0.0f); + // float32x4_t sumv1 = vdupq_n_f32(0.0f); + // float32x4_t sumv2 = vdupq_n_f32(0.0f); + // float32x4_t sumv3 = vdupq_n_f32(0.0f); + // const unsigned char* w_start = ¶ms->B.int4_data_ptr[j * k / 2]; + // const signed char* a_start = ¶ms->A.int8_data_ptr[i * k]; + // float* s_a = ¶ms->A_scales[i * k / 32]; + // float* s_w = ¶ms->scales[j * k / 32]; + + // const uint8x16_t mask_low4bit = vdupq_n_u8(0xf); + // const int8x16_t offsets = vdupq_n_s8(8); + // for (int q = 0; q < num_block; q += 4) { + // int32x4_t int_sum0 = vdupq_n_s32(0); + // int32x4_t int_sum1 = vdupq_n_s32(0); + // int32x4_t int_sum2 = vdupq_n_s32(0); + // int32x4_t int_sum3 = vdupq_n_s32(0); + // float s_0 = *s_a++ * *s_w++; + // float s_1 = *s_a++ * *s_w++; + // float s_2 = *s_a++ * *s_w++; + // float s_3 = *s_a++ * *s_w++; + + // const uint8x16_t w0 = vld1q_u8(w_start); // 32 4bit weight + // const uint8x16_t w1 = vld1q_u8(w_start + 16); // 32 4bit weight + // const uint8x16_t w2 = vld1q_u8(w_start + 32); // 32 4bit weight + // const uint8x16_t w3 = vld1q_u8(w_start + 48); // 32 4bit weight + // w_start += 64; + + // // Quantization Method QM_ARM, convert 64 4-bit to 64 8-bit + // // sequential: (0, 1), (2, 3), (4, 5), (6, 7)... : 128 bit + // // expected layout of inB: (0, 16), (1, 17), (2, 18), (3, 19)... + // // low; (0, 0), (1, 0), (2, 0), (3, 0) ... + // // high: (16, 0), (17, 0), (18, 0), (19, 0) ... + // int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit)); + // int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4)); + // int8x16_t w1_low = vreinterpretq_s8_u8(vandq_u8(w1, mask_low4bit)); + // int8x16_t w1_high = vreinterpretq_s8_u8(vshrq_n_u8(w1, 4)); + // int8x16_t w2_low = vreinterpretq_s8_u8(vandq_u8(w2, mask_low4bit)); + // int8x16_t w2_high = vreinterpretq_s8_u8(vshrq_n_u8(w2, 4)); + // int8x16_t w3_low = vreinterpretq_s8_u8(vandq_u8(w3, mask_low4bit)); + // int8x16_t w3_high = vreinterpretq_s8_u8(vshrq_n_u8(w3, 4)); + + // // apply offset + // w0_low = vsubq_s8(w0_low, offsets); + // w0_high = vsubq_s8(w0_high, offsets); + // w1_low = vsubq_s8(w1_low, offsets); + // w1_high = vsubq_s8(w1_high, offsets); + // w2_low = vsubq_s8(w2_low, offsets); + // w2_high = vsubq_s8(w2_high, offsets); + // w3_low = vsubq_s8(w3_low, offsets); + // w3_high = vsubq_s8(w3_high, offsets); + + // // load 64 8-bit activation + // const int8x16_t a0 = vld1q_s8(a_start); + // const int8x16_t a1 = vld1q_s8(a_start + 16); + // const int8x16_t a2 = vld1q_s8(a_start + 32); + // const int8x16_t a3 = vld1q_s8(a_start + 48); + // const int8x16_t a4 = vld1q_s8(a_start + 64); + // const int8x16_t a5 = vld1q_s8(a_start + 80); + // const int8x16_t a6 = vld1q_s8(a_start + 96); + // const int8x16_t a7 = vld1q_s8(a_start + 112); + // a_start += 128; + + // // dot product into int32x4_t + // int_sum0 = my_vdotq_s32(int_sum0, w0_low, a0); + // int_sum0 = my_vdotq_s32(int_sum0, w0_high, a1); + // int_sum1 = my_vdotq_s32(int_sum1, w1_low, a2); + // int_sum1 = my_vdotq_s32(int_sum1, w1_high, a3); + // int_sum2 = my_vdotq_s32(int_sum2, w2_low, a4); + // int_sum2 = my_vdotq_s32(int_sum2, w2_high, a5); + // int_sum3 = my_vdotq_s32(int_sum3, w3_low, a6); + // int_sum3 = my_vdotq_s32(int_sum3, w3_high, a7); + + // sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0); + // sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(int_sum1), s_1); + // sumv2 = vmlaq_n_f32(sumv2, vcvtq_f32_s32(int_sum2), s_2); + // sumv3 = vmlaq_n_f32(sumv3, vcvtq_f32_s32(int_sum3), s_3); + // } + // if (params->bias.data_ptr) { + // params->C.data_ptr[i * n + j] = params->bias.data_ptr[j] + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + + // vaddvq_f32(sumv2) + vaddvq_f32(sumv3); + // } + // else { + // params->C.data_ptr[i * n + j] = + // vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + vaddvq_f32(sumv2) + vaddvq_f32(sumv3); + // } + + + ////////////////////////////////////////// + // float32x4_t sumv0 = vdupq_n_f32(0.0f); + // float32x4_t sumv1 = vdupq_n_f32(0.0f); + // const unsigned char* w_start = ¶ms->B.int4_data_ptr[j * k / 2]; + // const signed char* a_start = ¶ms->A.int8_data_ptr[i * k]; + // float* s_a = ¶ms->A_scales[i * k / 32]; + // float* s_w = ¶ms->scales[j * k / 32]; + + // const uint8x16_t mask_low4bit = vdupq_n_u8(0xf); + // const int8x16_t offsets = vdupq_n_s8(8); + // for (int q = 0; q < num_block; q += 2) { + // int32x4_t int_sum0 = vdupq_n_s32(0); + // int32x4_t int_sum1 = vdupq_n_s32(0); + // float s_0 = *s_a++ * *s_w++; + // float s_1 = *s_a++ * *s_w++; + + // const uint8x16_t w0 = vld1q_u8(w_start); // 32 4bit weight + // const uint8x16_t w1 = vld1q_u8(w_start + 16); // 32 4bit weight + // w_start += 32; + + // // Quantization Method QM_ARM, convert 64 4-bit to 64 8-bit + // // sequential: (0, 1), (2, 3), (4, 5), (6, 7)... : 128 bit + // // expected layout of inB: (0, 16), (1, 17), (2, 18), (3, 19)... + // // low; (0, 0), (1, 0), (2, 0), (3, 0) ... + // // high: (16, 0), (17, 0), (18, 0), (19, 0) ... + // int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit)); + // int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4)); + // int8x16_t w1_low = vreinterpretq_s8_u8(vandq_u8(w1, mask_low4bit)); + // int8x16_t w1_high = vreinterpretq_s8_u8(vshrq_n_u8(w1, 4)); + + // // load 64 8-bit activation + // const int8x16_t a0 = vld1q_s8(a_start); + // const int8x16_t a1 = vld1q_s8(a_start + 16); + // const int8x16_t a2 = vld1q_s8(a_start + 32); + // const int8x16_t a3 = vld1q_s8(a_start + 48); + // a_start += 64; + + // // apply offset + // w0_low = vsubq_s8(w0_low, offsets); + // w0_high = vsubq_s8(w0_high, offsets); + // w1_low = vsubq_s8(w1_low, offsets); + // w1_high = vsubq_s8(w1_high, offsets); + + // // dot product into int32x4_t + // int_sum0 = my_vdotq_s32(int_sum0, w0_low, a0); + // int_sum1 = my_vdotq_s32(int_sum1, w1_low, a2); + // int_sum0 = my_vdotq_s32(int_sum0, w0_high, a1); + // int_sum1 = my_vdotq_s32(int_sum1, w1_high, a3); + + // sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0); + // sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(int_sum1), s_1); + // } + // params->C.data_ptr[i * n + j] = vaddvq_f32(sumv0) + vaddvq_f32(sumv1); + } + } + } + } + + // Leftover rows w/o tiling + int left_start_i = mat_args->start_i + ((mat_args->end_i - mat_args->start_i) / TILE_SIZE) * TILE_SIZE; + for (int i = left_start_i; i < mat_args->end_i; i++) { + for (int j = 0; j < n; j++) { + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + float32x4_t sumv2 = vdupq_n_f32(0.0f); + float32x4_t sumv3 = vdupq_n_f32(0.0f); + const unsigned char* w_start = ¶ms->B.int4_data_ptr[j * k / 2]; + const signed char* a_start = ¶ms->A.int8_data_ptr[i * k]; + float* s_a = ¶ms->A_scales[i * k / 32]; + float* s_w = ¶ms->scales[j * k / 32]; + + const uint8x16_t mask_low4bit = vdupq_n_u8(0xf); + const int8x16_t offsets = vdupq_n_s8(8); + for (int q = 0; q < num_block; q += 4) { + int32x4_t int_sum0 = vdupq_n_s32(0); + int32x4_t int_sum1 = vdupq_n_s32(0); + int32x4_t int_sum2 = vdupq_n_s32(0); + int32x4_t int_sum3 = vdupq_n_s32(0); + float s_0 = *s_a++ * *s_w++; + float s_1 = *s_a++ * *s_w++; + float s_2 = *s_a++ * *s_w++; + float s_3 = *s_a++ * *s_w++; + + const uint8x16_t w0 = vld1q_u8(w_start); // 32 4bit weight + const uint8x16_t w1 = vld1q_u8(w_start + 16); // 32 4bit weight + const uint8x16_t w2 = vld1q_u8(w_start + 32); // 32 4bit weight + const uint8x16_t w3 = vld1q_u8(w_start + 48); // 32 4bit weight + w_start += 64; + + // Quantization Method QM_ARM, convert 64 4-bit to 64 8-bit + // sequential: (0, 1), (2, 3), (4, 5), (6, 7)... : 128 bit + // expected layout of inB: (0, 16), (1, 17), (2, 18), (3, 19)... + // low; (0, 0), (1, 0), (2, 0), (3, 0) ... + // high: (16, 0), (17, 0), (18, 0), (19, 0) ... + int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit)); + int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4)); + int8x16_t w1_low = vreinterpretq_s8_u8(vandq_u8(w1, mask_low4bit)); + int8x16_t w1_high = vreinterpretq_s8_u8(vshrq_n_u8(w1, 4)); + int8x16_t w2_low = vreinterpretq_s8_u8(vandq_u8(w2, mask_low4bit)); + int8x16_t w2_high = vreinterpretq_s8_u8(vshrq_n_u8(w2, 4)); + int8x16_t w3_low = vreinterpretq_s8_u8(vandq_u8(w3, mask_low4bit)); + int8x16_t w3_high = vreinterpretq_s8_u8(vshrq_n_u8(w3, 4)); + + // apply offset + w0_low = vsubq_s8(w0_low, offsets); + w0_high = vsubq_s8(w0_high, offsets); + w1_low = vsubq_s8(w1_low, offsets); + w1_high = vsubq_s8(w1_high, offsets); + w2_low = vsubq_s8(w2_low, offsets); + w2_high = vsubq_s8(w2_high, offsets); + w3_low = vsubq_s8(w3_low, offsets); + w3_high = vsubq_s8(w3_high, offsets); + + // load 64 8-bit activation + const int8x16_t a0 = vld1q_s8(a_start); + const int8x16_t a1 = vld1q_s8(a_start + 16); + const int8x16_t a2 = vld1q_s8(a_start + 32); + const int8x16_t a3 = vld1q_s8(a_start + 48); + const int8x16_t a4 = vld1q_s8(a_start + 64); + const int8x16_t a5 = vld1q_s8(a_start + 80); + const int8x16_t a6 = vld1q_s8(a_start + 96); + const int8x16_t a7 = vld1q_s8(a_start + 112); + a_start += 128; + + // dot product into int32x4_t + int_sum0 = my_vdotq_s32(int_sum0, w0_low, a0); + int_sum0 = my_vdotq_s32(int_sum0, w0_high, a1); + int_sum1 = my_vdotq_s32(int_sum1, w1_low, a2); + int_sum1 = my_vdotq_s32(int_sum1, w1_high, a3); + int_sum2 = my_vdotq_s32(int_sum2, w2_low, a4); + int_sum2 = my_vdotq_s32(int_sum2, w2_high, a5); + int_sum3 = my_vdotq_s32(int_sum3, w3_low, a6); + int_sum3 = my_vdotq_s32(int_sum3, w3_high, a7); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(int_sum1), s_1); + sumv2 = vmlaq_n_f32(sumv2, vcvtq_f32_s32(int_sum2), s_2); + sumv3 = vmlaq_n_f32(sumv3, vcvtq_f32_s32(int_sum3), s_3); + } + if (params->bias.data_ptr) { + params->C.data_ptr[i * n + j] = params->bias.data_ptr[j] + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + + vaddvq_f32(sumv2) + vaddvq_f32(sumv3); + } + else { + params->C.data_ptr[i * n + j] = + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + vaddvq_f32(sumv2) + vaddvq_f32(sumv3); + } + } + } + + return NULL; +} + +inline static void* gemm_int8_int4_no_offset_over_column_unroll128_v2(void* args) { + struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args; + const struct matmul_params* params = mat_args->params; + int n = params->C.column, m = params->C.row, k = params->A.column, block_size = params->block_size; + const int num_block = k / block_size; + int TILE_SIZE = mat_args->tile_size; + + // assert((mat_args->end_j - mat_args->start_j) % TILE_SIZE == 0); + assert(k % TILE_SIZE == 0); + assert(n % TILE_SIZE == 0); + // assert(TILE_SIZE % 4 == 0); + + for (int ti = 0; ti < m; ti += TILE_SIZE) { + for (int tj = mat_args->start_j; tj < mat_args->end_j; tj += TILE_SIZE) { + for (int i = ti; i < ti + TILE_SIZE; i++) { + for (int j = tj; j < tj + TILE_SIZE; j++) { + float32x4_t sumv0 = vdupq_n_f32(0.0f); + float32x4_t sumv1 = vdupq_n_f32(0.0f); + float32x4_t sumv2 = vdupq_n_f32(0.0f); + float32x4_t sumv3 = vdupq_n_f32(0.0f); + const unsigned char* w_start = ¶ms->B.int4_data_ptr[j * k / 2]; + const signed char* a_start = ¶ms->A.int8_data_ptr[i * k]; + float* s_a = ¶ms->A_scales[i * k / 32]; + float* s_w = ¶ms->scales[j * k / 32]; + + const uint8x16_t mask_low4bit = vdupq_n_u8(0xf); + const int8x16_t offsets = vdupq_n_s8(8); + for (int q = 0; q < num_block; q += 4) { + int32x4_t int_sum0 = vdupq_n_s32(0); + int32x4_t int_sum1 = vdupq_n_s32(0); + int32x4_t int_sum2 = vdupq_n_s32(0); + int32x4_t int_sum3 = vdupq_n_s32(0); + float s_0 = *s_a++ * *s_w++; + float s_1 = *s_a++ * *s_w++; + float s_2 = *s_a++ * *s_w++; + float s_3 = *s_a++ * *s_w++; + + const uint8x16_t w0 = vld1q_u8(w_start); // 32 4bit weight + const uint8x16_t w1 = vld1q_u8(w_start + 16); // 32 4bit weight + const uint8x16_t w2 = vld1q_u8(w_start + 32); // 32 4bit weight + const uint8x16_t w3 = vld1q_u8(w_start + 48); // 32 4bit weight + w_start += 64; + + // Quantization Method QM_ARM, convert 64 4-bit to 64 8-bit + // sequential: (0, 1), (2, 3), (4, 5), (6, 7)... : 128 bit + // expected layout of inB: (0, 16), (1, 17), (2, 18), (3, 19)... + // low; (0, 0), (1, 0), (2, 0), (3, 0) ... + // high: (16, 0), (17, 0), (18, 0), (19, 0) ... + int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit)); + int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4)); + int8x16_t w1_low = vreinterpretq_s8_u8(vandq_u8(w1, mask_low4bit)); + int8x16_t w1_high = vreinterpretq_s8_u8(vshrq_n_u8(w1, 4)); + int8x16_t w2_low = vreinterpretq_s8_u8(vandq_u8(w2, mask_low4bit)); + int8x16_t w2_high = vreinterpretq_s8_u8(vshrq_n_u8(w2, 4)); + int8x16_t w3_low = vreinterpretq_s8_u8(vandq_u8(w3, mask_low4bit)); + int8x16_t w3_high = vreinterpretq_s8_u8(vshrq_n_u8(w3, 4)); + + // apply offset + w0_low = vsubq_s8(w0_low, offsets); + w0_high = vsubq_s8(w0_high, offsets); + w1_low = vsubq_s8(w1_low, offsets); + w1_high = vsubq_s8(w1_high, offsets); + w2_low = vsubq_s8(w2_low, offsets); + w2_high = vsubq_s8(w2_high, offsets); + w3_low = vsubq_s8(w3_low, offsets); + w3_high = vsubq_s8(w3_high, offsets); + + // load 64 8-bit activation + const int8x16_t a0 = vld1q_s8(a_start); + const int8x16_t a1 = vld1q_s8(a_start + 16); + const int8x16_t a2 = vld1q_s8(a_start + 32); + const int8x16_t a3 = vld1q_s8(a_start + 48); + const int8x16_t a4 = vld1q_s8(a_start + 64); + const int8x16_t a5 = vld1q_s8(a_start + 80); + const int8x16_t a6 = vld1q_s8(a_start + 96); + const int8x16_t a7 = vld1q_s8(a_start + 112); + a_start += 128; + + // dot product into int32x4_t + int_sum0 = my_vdotq_s32(int_sum0, w0_low, a0); + int_sum0 = my_vdotq_s32(int_sum0, w0_high, a1); + int_sum1 = my_vdotq_s32(int_sum1, w1_low, a2); + int_sum1 = my_vdotq_s32(int_sum1, w1_high, a3); + int_sum2 = my_vdotq_s32(int_sum2, w2_low, a4); + int_sum2 = my_vdotq_s32(int_sum2, w2_high, a5); + int_sum3 = my_vdotq_s32(int_sum3, w3_low, a6); + int_sum3 = my_vdotq_s32(int_sum3, w3_high, a7); + + sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0); + sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(int_sum1), s_1); + sumv2 = vmlaq_n_f32(sumv2, vcvtq_f32_s32(int_sum2), s_2); + sumv3 = vmlaq_n_f32(sumv3, vcvtq_f32_s32(int_sum3), s_3); + } + if (params->bias.data_ptr) { + params->C.data_ptr[i * n + j] = params->bias.data_ptr[j] + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + + vaddvq_f32(sumv2) + vaddvq_f32(sumv3); + } + else { + params->C.data_ptr[i * n + j] = + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + vaddvq_f32(sumv2) + vaddvq_f32(sumv3); + } + } + } + } + } + + // // Leftover rows w/o tiling + int left_start_j = mat_args->start_j + ((mat_args->end_j - mat_args->start_j) / TILE_SIZE) * TILE_SIZE; + // for (int i = 0; i < m; i++) { + // for (int j = left_start_j; i < mat_args->end_j; j++) { + // float32x4_t sumv0 = vdupq_n_f32(0.0f); + // float32x4_t sumv1 = vdupq_n_f32(0.0f); + // float32x4_t sumv2 = vdupq_n_f32(0.0f); + // float32x4_t sumv3 = vdupq_n_f32(0.0f); + // const unsigned char* w_start = ¶ms->B.int4_data_ptr[j * k / 2]; + // const signed char* a_start = ¶ms->A.int8_data_ptr[i * k]; + // float* s_a = ¶ms->A_scales[i * k / 32]; + // float* s_w = ¶ms->scales[j * k / 32]; + + // const uint8x16_t mask_low4bit = vdupq_n_u8(0xf); + // const int8x16_t offsets = vdupq_n_s8(8); + // for (int q = 0; q < num_block; q += 4) { + // int32x4_t int_sum0 = vdupq_n_s32(0); + // int32x4_t int_sum1 = vdupq_n_s32(0); + // int32x4_t int_sum2 = vdupq_n_s32(0); + // int32x4_t int_sum3 = vdupq_n_s32(0); + // float s_0 = *s_a++ * *s_w++; + // float s_1 = *s_a++ * *s_w++; + // float s_2 = *s_a++ * *s_w++; + // float s_3 = *s_a++ * *s_w++; + + // const uint8x16_t w0 = vld1q_u8(w_start); // 32 4bit weight + // const uint8x16_t w1 = vld1q_u8(w_start + 16); // 32 4bit weight + // const uint8x16_t w2 = vld1q_u8(w_start + 32); // 32 4bit weight + // const uint8x16_t w3 = vld1q_u8(w_start + 48); // 32 4bit weight + // w_start += 64; + + // // Quantization Method QM_ARM, convert 64 4-bit to 64 8-bit + // // sequential: (0, 1), (2, 3), (4, 5), (6, 7)... : 128 bit + // // expected layout of inB: (0, 16), (1, 17), (2, 18), (3, 19)... + // // low; (0, 0), (1, 0), (2, 0), (3, 0) ... + // // high: (16, 0), (17, 0), (18, 0), (19, 0) ... + // int8x16_t w0_low = vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit)); + // int8x16_t w0_high = vreinterpretq_s8_u8(vshrq_n_u8(w0, 4)); + // int8x16_t w1_low = vreinterpretq_s8_u8(vandq_u8(w1, mask_low4bit)); + // int8x16_t w1_high = vreinterpretq_s8_u8(vshrq_n_u8(w1, 4)); + // int8x16_t w2_low = vreinterpretq_s8_u8(vandq_u8(w2, mask_low4bit)); + // int8x16_t w2_high = vreinterpretq_s8_u8(vshrq_n_u8(w2, 4)); + // int8x16_t w3_low = vreinterpretq_s8_u8(vandq_u8(w3, mask_low4bit)); + // int8x16_t w3_high = vreinterpretq_s8_u8(vshrq_n_u8(w3, 4)); + + // // apply offset + // w0_low = vsubq_s8(w0_low, offsets); + // w0_high = vsubq_s8(w0_high, offsets); + // w1_low = vsubq_s8(w1_low, offsets); + // w1_high = vsubq_s8(w1_high, offsets); + // w2_low = vsubq_s8(w2_low, offsets); + // w2_high = vsubq_s8(w2_high, offsets); + // w3_low = vsubq_s8(w3_low, offsets); + // w3_high = vsubq_s8(w3_high, offsets); + + // // load 64 8-bit activation + // const int8x16_t a0 = vld1q_s8(a_start); + // const int8x16_t a1 = vld1q_s8(a_start + 16); + // const int8x16_t a2 = vld1q_s8(a_start + 32); + // const int8x16_t a3 = vld1q_s8(a_start + 48); + // const int8x16_t a4 = vld1q_s8(a_start + 64); + // const int8x16_t a5 = vld1q_s8(a_start + 80); + // const int8x16_t a6 = vld1q_s8(a_start + 96); + // const int8x16_t a7 = vld1q_s8(a_start + 112); + // a_start += 128; + + // // dot product into int32x4_t + // int_sum0 = my_vdotq_s32(int_sum0, w0_low, a0); + // int_sum0 = my_vdotq_s32(int_sum0, w0_high, a1); + // int_sum1 = my_vdotq_s32(int_sum1, w1_low, a2); + // int_sum1 = my_vdotq_s32(int_sum1, w1_high, a3); + // int_sum2 = my_vdotq_s32(int_sum2, w2_low, a4); + // int_sum2 = my_vdotq_s32(int_sum2, w2_high, a5); + // int_sum3 = my_vdotq_s32(int_sum3, w3_low, a6); + // int_sum3 = my_vdotq_s32(int_sum3, w3_high, a7); + + // sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0); + // sumv1 = vmlaq_n_f32(sumv1, vcvtq_f32_s32(int_sum1), s_1); + // sumv2 = vmlaq_n_f32(sumv2, vcvtq_f32_s32(int_sum2), s_2); + // sumv3 = vmlaq_n_f32(sumv3, vcvtq_f32_s32(int_sum3), s_3); + // } + // if (params->bias.data_ptr) { + // params->C.data_ptr[i * n + j] = params->bias.data_ptr[j] + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + + // vaddvq_f32(sumv2) + vaddvq_f32(sumv3); + // } + // else { + // params->C.data_ptr[i * n + j] = + // vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + vaddvq_f32(sumv2) + vaddvq_f32(sumv3); + // } + // } + // } + + return NULL; +} + inline static void* matmul_int8_int4_no_offset_over_column_unroll128(void* args) { struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args; const struct matmul_params* params = mat_args->params; @@ -346,7 +968,7 @@ inline static void* matmul_int8_int4_no_offset_over_column_unroll128(void* args) const signed char* a_start = ¶ms->A.int8_data_ptr[i * k]; float* s_a = ¶ms->A_scales[i * k / 32]; float* s_w = ¶ms->scales[j * k / 32]; - + const uint8x16_t mask_low4bit = vdupq_n_u8(0xf); const int8x16_t offsets = vdupq_n_s8(8); for (int q = 0; q < num_block; q += 4) { @@ -643,6 +1265,28 @@ static void* matmul_int8_int4_no_offset_over_column_packed(void* args) { return NULL; } +inline static void* fp32_matmul_transposed_cblas_gemm(void* args) { + struct a8w4_thread_args* mat_args = (struct a8w4_thread_args*)args; + const struct matmul_params* params = mat_args->params; + + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + float *data_A = A->data_ptr + mat_args->start_j * A->column; + float *data_B = B->data_ptr; + float *data_C = C->data_ptr + mat_args->start_j * C->column; + float alpha = params->alpha; + + int n = C->column, k = A->column; + int m = mat_args->end_j - mat_args->start_j; + + cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, + m, n, k, + alpha, data_A, k, + data_B, k, + 0.0f, data_C, n); + + return NULL; +} + namespace matmul { void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_params* params) { int i, j, k; @@ -724,4 +1368,107 @@ void MatmulOperator::gemv_accelerator_int8_int4_fast_no_offset(struct matmul_par // Join threads pool_wait(pool); }; + +void MatmulOperator::gemm_accelerator_int8_int4_fast_no_offset(struct matmul_params* params) { + int i, j, k; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + const int block_size = params->block_size; + float *scale = params->scales, *offset = params->offset; + assert(params->block_size % 32 == 0); // support block size to be multiply of 32 + assert(A->row == C->row); // support block size to be multiply of 32 + + quantize_fp32_to_int8(A->data_ptr, A->int8_data_ptr, params->A_scales, A->row * A->column, block_size); + + const int num_thread = params->opt_params.num_thread; + struct a8w4_thread_args threads_args[num_thread]; + assert(params->block_size == 32); // support block size 32 for now + + static void *pool = pool_start(gemm_int8_int4_no_offset_over_column_unroll128, num_thread); + + // Thread creation + for (j = 0; j < num_thread; j++) { + threads_args[j].start_i = j * (params->C.row / num_thread); + if (j == num_thread - 1) { + threads_args[j].end_i = params->C.row; + } else { + threads_args[j].end_i = (j + 1) * (params->C.row / num_thread); + } + threads_args[j].tile_size = 4; + threads_args[j].params = params; + pool_enqueue(pool, &threads_args[j], '\0'); + } + // Join threads + pool_wait(pool); +}; + +void MatmulOperator::gemm_accelerator_int8_int4_fast_no_offset_v2(struct matmul_params* params) { + int i, j, k; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + const int block_size = params->block_size; + float *scale = params->scales, *offset = params->offset; + assert(params->block_size % 32 == 0); // support block size to be multiply of 32 + assert(A->row == C->row); // support block size to be multiply of 32 + + quantize_fp32_to_int8(A->data_ptr, A->int8_data_ptr, params->A_scales, A->row * A->column, block_size); + + const int num_thread = params->opt_params.num_thread; + struct a8w4_thread_args threads_args[num_thread]; + assert(params->block_size == 32); // support block size 32 for now + + static void *pool = pool_start(gemm_int8_int4_no_offset_over_column_unroll128_v2, num_thread); + + // Thread creation + for (j = 0; j < num_thread; j++) { + threads_args[j].start_j = j * (params->C.column / num_thread); + if (j == num_thread - 1) { + threads_args[j].end_j = params->C.column; + } else { + threads_args[j].end_j = (j + 1) * (params->C.column / num_thread); + } + threads_args[j].tile_size = 4; + threads_args[j].params = params; + pool_enqueue(pool, &threads_args[j], '\0'); + } + // Join threads + pool_wait(pool); +}; + +void MatmulOperator::cblas_gemm_accelerator_no_offset(struct matmul_params* params) { + int i, j, k; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + const int block_size = params->block_size; + float *scale = params->scales, *offset = params->offset; + assert(params->block_size % 32 == 0); // support block size to be multiply of 32 + assert(A->row == C->row); // support block size to be multiply of 32 + + dequantize_int4_to_fp32(B->int4_data_ptr, B->data_ptr, params->scales, A->column * C->column, block_size); + + const int num_thread = params->opt_params.num_thread; + struct a8w4_thread_args threads_args[num_thread]; + assert(params->block_size == 32); // support block size 32 for now + + // mat_mul_accelerator_transposed_fastover_column(params); + + static void *pool = pool_start(fp32_matmul_transposed_cblas_gemm, num_thread); + + // Thread creation + for (j = 0; j < num_thread; j++) { + threads_args[j].start_j = j * (params->C.row / num_thread); + if (j == num_thread - 1) { + threads_args[j].end_j = params->C.row; + } else { + threads_args[j].end_j = (j + 1) * (params->C.row / num_thread); + } + // ¶ms->A.data_ptr = threads_args[j].start_j * params->A.column; + // params->A.row = threads_args[j].end_j - threads_args[j].start_j; + // ¶ms->C.data_ptr = threads_args[j].start_j * params->C.column; + // params->C.row = threads_args[j].end_j - threads_args[j].start_j; + // threads_args[j].tile_size = 4; + threads_args[j].params = params; + pool_enqueue(pool, &threads_args[j], '\0'); + } + // Join threads + pool_wait(pool); +}; + } // namespace matmul diff --git a/llm/Makefile b/llm/Makefile index f2685c18..94c34bb9 100644 --- a/llm/Makefile +++ b/llm/Makefile @@ -113,7 +113,7 @@ else ifeq ($(shell uname -p),arm) # Use NEON with int8 runtime quantization is faster else LIB_SRC += $(wildcard $(LIB_DIR)/neon/*.cc) - CXXFLAGS += -march=native -DUSE_INT8_INT4_PRODUCT -DQM_ARM -fPIC -march=armv8.2-a -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 + CXXFLAGS += -march=native -DUSE_INT8_INT4_PRODUCT -DQM_ARM -fPIC -march=armv8.2-a -DACCELERATE_NEW_LAPACK -DACCELERATE_LAPACK_ILP64 -DUSE_ACCELERATE LDFLAGS += -framework Accelerate INCLUDE_DIRS += -I/opt/homebrew/opt/boost/include endif diff --git a/llm/application/chat.cc b/llm/application/chat.cc index d96639c7..c51b926d 100644 --- a/llm/application/chat.cc +++ b/llm/application/chat.cc @@ -4,6 +4,7 @@ #include #include "Generate.h" +#include "interface.h" std::map model_config = { {"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B}, @@ -114,7 +115,10 @@ int main(int argc, char* argv[]) { std::string img_path = "images/monalisa.jpg"; Profiler::getInstance().for_demo = true; + // Set prompt color + set_print_yellow(); std::cout << "TinyChatEngine by MIT HAN Lab: https://github.com/mit-han-lab/TinyChatEngine" << std::endl; + if (argc >= 3 && argc <= 5) { auto target_str = argv[1]; target_model = argv[1]; @@ -192,7 +196,7 @@ int main(int argc, char* argv[]) { int format_id = data_format_list[target_data_format]; // Voicechat instructions - if (use_voicechat){ + if (use_voicechat) { std::cout << "You are using the TinyVoiceChat." << std::endl; std::cout << "*Usage instructions*" << std::endl; std::cout << "- Please use this mode in a quiet environment to have a better user experience and avoid speech misdetection." << std::endl; @@ -224,26 +228,38 @@ int main(int argc, char* argv[]) { if (format_id == FP32) { Fp32LlamaForCausalLM model = Fp32LlamaForCausalLM(m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; + std::cout << "Finished!" << std::endl << std::endl; // Get input from the user while (true) { std::string input; - if (use_voicechat){ + if (use_voicechat) { + // Set prompt color + set_print_yellow(); int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); + // set user input color + set_print_red(); std::getline(in, input); result = std::system("rm tmpfile"); (void)result; std::cout << input << std::endl; + // reset color + set_print_reset(); } else { + // Set prompt color + set_print_yellow(); std::cout << "USER: "; + // set user input color + set_print_red(); std::getline(std::cin, input); + // reset color + set_print_reset(); } if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.") break; if (instruct) { - std::cout << "ASSISTANT: " << std::endl; + std::cout << "ASSISTANT: "; if (isCodeLLaMA(target_model)) { if (first_prompt) { input = "[INST] " + input + " [/INST] "; @@ -275,26 +291,38 @@ int main(int argc, char* argv[]) { } else if (format_id == INT4) { m_path = "INT4/" + m_path; Int4LlamaForCausalLM model = Int4LlamaForCausalLM(m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; + std::cout << "Finished!" << std::endl << std::endl; // Get input from the user while (true) { std::string input; - if (use_voicechat){ + if (use_voicechat) { + // Set prompt color + set_print_yellow(); int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); + // set user input color + set_print_red(); std::getline(in, input); result = std::system("rm tmpfile"); (void)result; std::cout << input << std::endl; + // reset color + set_print_reset(); } else { + // Set prompt color + set_print_yellow(); std::cout << "USER: "; + // set user input color + set_print_red(); std::getline(std::cin, input); + // reset color + set_print_reset(); } if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.") break; if (instruct) { - std::cout << "ASSISTANT: " << std::endl; + std::cout << "ASSISTANT: "; if (isCodeLLaMA(target_model)) { if (first_prompt) { input = "[INST] " + input + " [/INST] "; @@ -347,28 +375,40 @@ int main(int argc, char* argv[]) { if (format_id == FP32) { Fp32GPTBigCodeForCausalLM model = Fp32GPTBigCodeForCausalLM(m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; + std::cout << "Finished!" << std::endl << std::endl; // Get input from the user while (true) { + // Set prompt color + set_print_yellow(); std::cout << "USER: "; std::string input; + // set user input color + set_print_red(); std::getline(std::cin, input); std::cout << input; + // reset color + set_print_reset(); GPTBigCodeGenerate(m_path, &model, StarCoder_FP32, input, generation_config, "models/starcoder_vocab.bin", true); } } else if (format_id == INT4) { m_path = "INT4/" + m_path; Int4GPTBigCodeForCausalLM model = Int4GPTBigCodeForCausalLM(m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; + std::cout << "Finished!" << std::endl << std::endl; // Get input from the user while (true) { + // Set prompt color + set_print_yellow(); std::cout << "USER: "; std::string input; + // set user input color + set_print_red(); std::getline(std::cin, input); std::cout << input; + // reset color + set_print_reset(); GPTBigCodeGenerate(m_path, &model, StarCoder_INT4, input, generation_config, "models/starcoder_vocab.bin", true); } @@ -380,7 +420,7 @@ int main(int argc, char* argv[]) { int format_id = data_format_list[target_data_format]; // Voicechat instructions - if (use_voicechat){ + if (use_voicechat) { std::cout << "You are using the TinyVoiceChat." << std::endl; std::cout << "*Usage instructions*" << std::endl; std::cout << "- Please use this mode in a quiet environment to have a better user experience and avoid speech misdetection." << std::endl; @@ -416,23 +456,39 @@ int main(int argc, char* argv[]) { while (true) { std::string input; if (prompt_iter == 1) { - std::cout << "Finished!" << std::endl; + // Set prompt color + set_print_yellow(); + std::cout << "Finished!" << std::endl << std::endl; + // reset color + set_print_reset(); } if (prompt_iter > 0) { - if (use_voicechat){ + if (use_voicechat) { + // Set prompt color + set_print_yellow(); int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); + // set user input color + set_print_red(); std::getline(in, input); result = std::system("rm tmpfile"); (void)result; std::cout << input << std::endl; + // reset color + set_print_reset(); } else { + // Set prompt color + set_print_yellow(); std::cout << "USER: "; + // set user input color + set_print_red(); std::getline(std::cin, input); + // reset color + set_print_reset(); } if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.") break; - std::cout << "ASSISTANT: " << std::endl; + std::cout << "ASSISTANT: "; } if (prompt_iter == 0) { @@ -457,24 +513,40 @@ int main(int argc, char* argv[]) { // Get input from the user while (true) { if (prompt_iter == 1) { - std::cout << "Finished!" << std::endl; + // Set prompt color + set_print_yellow(); + std::cout << "Finished!" << std::endl << std::endl; + // reset color + set_print_reset(); } std::string input; if (prompt_iter > 0) { - if (use_voicechat){ + if (use_voicechat) { + // Set prompt color + set_print_yellow(); int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); + // set user input color + set_print_red(); std::getline(in, input); result = std::system("rm tmpfile"); (void)result; std::cout << input << std::endl; + // reset color + set_print_reset(); } else { + // Set prompt color + set_print_yellow(); std::cout << "USER: "; + // set user input color + set_print_red(); std::getline(std::cin, input); + // reset color + set_print_reset(); } if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.") break; - std::cout << "ASSISTANT: " << std::endl; + std::cout << "ASSISTANT: "; } if (prompt_iter == 0) { @@ -499,7 +571,7 @@ int main(int argc, char* argv[]) { int format_id = data_format_list[target_data_format]; // Voicechat instructions - if (use_voicechat){ + if (use_voicechat) { std::cout << "You are using the TinyVoiceChat." << std::endl; std::cout << "*Usage instructions*" << std::endl; std::cout << "- Please use this mode in a quiet environment to have a better user experience and avoid speech misdetection." << std::endl; @@ -536,23 +608,39 @@ int main(int argc, char* argv[]) { while (true) { std::string input; if (prompt_iter == 1) { - std::cout << "Finished!" << std::endl; + // Set prompt color + set_print_yellow(); + std::cout << "Finished!" << std::endl << std::endl; + // reset color + set_print_reset(); } if (prompt_iter > 0) { - if (use_voicechat){ + if (use_voicechat) { + // Set prompt color + set_print_yellow(); int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); + // set user input color + set_print_red(); std::getline(in, input); result = std::system("rm tmpfile"); (void)result; std::cout << input << std::endl; + // reset color + set_print_reset(); } else { + // Set prompt color + set_print_yellow(); std::cout << "USER: "; + // set user input color + set_print_red(); std::getline(std::cin, input); + // reset color + set_print_reset(); } if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.") break; - std::cout << "ASSISTANT: " << std::endl; + std::cout << "ASSISTANT: "; } if (prompt_iter == 0) { @@ -577,24 +665,40 @@ int main(int argc, char* argv[]) { // Get input from the user while (true) { if (prompt_iter == 1) { - std::cout << "Finished!" << std::endl; + // Set prompt color + set_print_yellow(); + std::cout << "Finished!" << std::endl << std::endl; + // reset color + set_print_reset(); } std::string input; if (prompt_iter > 0) { - if (use_voicechat){ + if (use_voicechat) { + // Set prompt color + set_print_yellow(); int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); + // set user input color + set_print_red(); std::getline(in, input); result = std::system("rm tmpfile"); (void)result; std::cout << input << std::endl; + // reset color + set_print_reset(); } else { + // Set prompt color + set_print_yellow(); std::cout << "USER: "; + // set user input color + set_print_red(); std::getline(std::cin, input); + // reset color + set_print_reset(); } if (input == "quit" || input == "Quit" || input == "Quit." || input == "quit.") break; - std::cout << "ASSISTANT: " << std::endl; + std::cout << "ASSISTANT: "; } if (prompt_iter == 0) { @@ -636,20 +740,32 @@ int main(int argc, char* argv[]) { generation_config.n_predict = 512; if (format_id == QINT8) { OPTForCausalLM model = OPTForCausalLM("INT8/" + m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; + std::cout << "Finished!" << std::endl << std::endl; // Get input from the user std::string input; - if (use_voicechat){ + if (use_voicechat) { + // Set prompt color + set_print_yellow(); int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); + // set user input color + set_print_red(); std::getline(in, input); result = std::system("rm tmpfile"); (void)result; std::cout << input << std::endl; + // reset color + set_print_reset(); } else { + // Set prompt color + set_print_yellow(); std::cout << "USER: "; + // set user input color + set_print_red(); std::getline(std::cin, input); + // reset color + set_print_reset(); } std::vector input_ids = encoder.encode(input); std::string decoded = encoder.decode(input_ids); @@ -659,20 +775,32 @@ int main(int argc, char* argv[]) { OPTGenerate(&model, OPT_INT8, input_ids, generation_config, &encoder, true, use_voicechat); } else if (format_id == FP32) { Fp32OPTForCausalLM model = Fp32OPTForCausalLM(m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; + std::cout << "Finished!" << std::endl << std::endl; // Get input from the user std::string input; - if (use_voicechat){ + if (use_voicechat) { + // Set prompt color + set_print_yellow(); int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); + // set user input color + set_print_red(); std::getline(in, input); result = std::system("rm tmpfile"); (void)result; std::cout << input << std::endl; + // reset color + set_print_reset(); } else { + // Set prompt color + set_print_yellow(); std::cout << "USER: "; + // set user input color + set_print_red(); std::getline(std::cin, input); + // reset color + set_print_reset(); } std::vector input_ids = encoder.encode(input); std::string decoded = encoder.decode(input_ids); @@ -682,20 +810,32 @@ int main(int argc, char* argv[]) { OPTGenerate(&model, OPT_FP32, input_ids, generation_config, &encoder, true, use_voicechat); } else if (format_id == INT4) { Int4OPTForCausalLM model = Int4OPTForCausalLM("INT4/" + m_path, get_opt_model_config(model_id)); - std::cout << "Finished!" << std::endl; + std::cout << "Finished!" << std::endl << std::endl; // Get input from the user std::string input; - if (use_voicechat){ + if (use_voicechat) { + // Set prompt color + set_print_yellow(); int result = std::system("./application/sts_utils/listen"); std::ifstream in("tmpfile"); + // set user input color + set_print_red(); std::getline(in, input); result = std::system("rm tmpfile"); (void)result; std::cout << input << std::endl; + // reset color + set_print_reset(); } else { + // Set prompt color + set_print_yellow(); std::cout << "USER: "; + // set user input color + set_print_red(); std::getline(std::cin, input); + // reset color + set_print_reset(); } std::vector input_ids = encoder.encode(input); diff --git a/llm/chat-13b b/llm/chat-13b new file mode 100755 index 00000000..5e46aafd --- /dev/null +++ b/llm/chat-13b @@ -0,0 +1,2 @@ +# !/bin/bash +./chat LLaMA2_13B_chat INT4 5 diff --git a/llm/code b/llm/code new file mode 100755 index 00000000..07e336a4 --- /dev/null +++ b/llm/code @@ -0,0 +1,2 @@ +# !/bin/bash +./chat CodeLLaMA_7B_Instruct INT4 5 diff --git a/llm/include/interface.h b/llm/include/interface.h new file mode 100644 index 00000000..fb17895c --- /dev/null +++ b/llm/include/interface.h @@ -0,0 +1,12 @@ +#ifndef INTERFACE_H +#define INTERFACE_H + +void set_print_black(); +void set_print_red(); +void set_print_yellow(); +void set_print_bold_yellow(); +void set_print_blue(); +void set_print_white(); +void set_print_reset(); + +#endif diff --git a/llm/include/ops/linear.h b/llm/include/ops/linear.h index 37c3d623..bccbcebc 100644 --- a/llm/include/ops/linear.h +++ b/llm/include/ops/linear.h @@ -126,6 +126,9 @@ class Linear_FP_int4 { void forward_fast(const Matrix3D &x, Matrix3D &output); #ifdef USE_INT8_INT4_PRODUCT static void initialize_memory(const int block_size); +#endif +#ifdef QM_ARM + static void initialize_weight_memory(); #endif Matrix3D weight; Matrix3D scale, zero_point; diff --git a/llm/include/profiler.h b/llm/include/profiler.h index 2c7ff4fe..f449949d 100644 --- a/llm/include/profiler.h +++ b/llm/include/profiler.h @@ -42,7 +42,7 @@ class Profiler { std::cout << entry.first + ", "; float s = (float)(entry.second) / 1000000; float ts = (float)counts.at(entry.first); - printf("Total time: %.1f s, %.1f ms/token, %.1f token/s, %d tokens\n", s, s / ts * 1000, ts / s, + printf("Total time: %.1f s, %.1f ms/token, %.1f token/s, %d tokens\n\n", s, s / ts * 1000, ts / s, counts.at(entry.first)); } } else { diff --git a/llm/scripts/llava.sh b/llm/scripts/llava.sh index e374698d..86ff0b1d 100755 --- a/llm/scripts/llava.sh +++ b/llm/scripts/llava.sh @@ -4,4 +4,4 @@ image_path="$1" termvisage $image_path -w 75 echo "=============================================================================================================================" -./chat LLaVA_7B INT4 5 $image_path +./chat LLaVA_7B INT4 6 $image_path diff --git a/llm/scripts/voice_llava.sh b/llm/scripts/voice_llava.sh index 803be634..a4cf37cf 100755 --- a/llm/scripts/voice_llava.sh +++ b/llm/scripts/voice_llava.sh @@ -4,4 +4,4 @@ image_path="$1" termvisage $image_path -w 75 echo "=============================================================================================================================" -./chat -v LLaVA_7B INT4 5 $image_path +./chat -v LLaVA_7B INT4 6 $image_path diff --git a/llm/src/GPTBigCodeGenerate.cc b/llm/src/GPTBigCodeGenerate.cc index 65b84190..f09b63b0 100644 --- a/llm/src/GPTBigCodeGenerate.cc +++ b/llm/src/GPTBigCodeGenerate.cc @@ -1,11 +1,12 @@ +#include +#include +#include #include "Generate.h" #include "GPTBigCodeTokenizer.h" #include "common.h" #include "utils.h" -#include -#include -#include +#include "interface.h" std::string GPTBigCodeGenerate(std::string param_path, void *model_ptr, int model_type, std::string text, const struct opt_params generation_config, std::string voc_path, bool interactive) { @@ -171,7 +172,12 @@ std::string GPTBigCodeGenerate(std::string param_path, void *model_ptr, int mode if (interactive) std::cout << std::endl; + // Set prompt color + set_print_yellow(); Profiler::getInstance().report_internal(); Profiler::getInstance().reset(); + // Reset color + set_print_reset(); + return output; } diff --git a/llm/src/OPTGenerate.cc b/llm/src/OPTGenerate.cc index c2e23c01..c008fad1 100644 --- a/llm/src/OPTGenerate.cc +++ b/llm/src/OPTGenerate.cc @@ -1,10 +1,12 @@ -#include "Generate.h" -#include "common.h" -#include "utils.h" #include #include #include +#include "Generate.h" +#include "common.h" +#include "utils.h" +#include "interface.h" + // Function to speak in the background void speakInBackground(const std::string& text) { std::string command = "./application/sts_utils/speak \"" + text + "\""; @@ -33,7 +35,7 @@ std::vector OPTGenerate(void *model_ptr, int model_type, std::vector i } if (encoder == NULL) interactive = false; - if (interactive) std::cout << "ASSISTANT: " << std::endl; + if (interactive) std::cout << "ASSISTANT: "; bool has_past_kv = false; std::vector> past_keys_int8, past_values_int8; @@ -222,8 +224,12 @@ std::vector OPTGenerate(void *model_ptr, int model_type, std::vector i } if (interactive) std::cout << std::endl; + // Set prompt color + set_print_yellow(); Profiler::getInstance().report_internal(); Profiler::getInstance().reset(); + // Reset color + set_print_reset(); return generate_ids; } diff --git a/llm/src/interface.cc b/llm/src/interface.cc new file mode 100644 index 00000000..6dddf575 --- /dev/null +++ b/llm/src/interface.cc @@ -0,0 +1,30 @@ +#include "interface.h" +#include + +void set_print_black() { + printf("\033[0;30m"); +} + +void set_print_red() { + printf("\033[1;31m"); +} + +void set_print_yellow() { + printf("\033[0;33m"); +} + +void set_print_bold_yellow() { + printf("\033[1;33m"); +} + +void set_print_blue() { + printf("\033[1;34m"); +} + +void set_print_white() { + printf("\033[0;37m"); +} + +void set_print_reset() { + printf("\033[0m"); +} diff --git a/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc b/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc index 3d65c0a7..e213b3ef 100644 --- a/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc +++ b/llm/src/nn_modules/non_cuda/LLaMAGenerate.cc @@ -1,11 +1,12 @@ +#include +#include +#include #include "Generate.h" #include "LLaMATokenizer.h" #include "common.h" #include "utils.h" -#include -#include -#include +#include "interface.h" // Function to speak in the background void sayInBackground(const std::string& text) { @@ -247,7 +248,12 @@ std::string LLaMAGenerate(std::string param_path, void *model_ptr, int model_typ if (interactive) std::cout << std::endl; + // Set prompt color + set_print_yellow(); Profiler::getInstance().report_internal(); Profiler::getInstance().reset(); + // Reset color + set_print_reset(); + return output; } diff --git a/llm/src/nn_modules/non_cuda/LLaVAGenerate.cc b/llm/src/nn_modules/non_cuda/LLaVAGenerate.cc index 7623037f..2900e183 100644 --- a/llm/src/nn_modules/non_cuda/LLaVAGenerate.cc +++ b/llm/src/nn_modules/non_cuda/LLaVAGenerate.cc @@ -1,11 +1,12 @@ +#include +#include +#include #include "Generate.h" #include "LLaMATokenizer.h" #include "common.h" #include "utils.h" -#include -#include -#include +#include "interface.h" #define STB_IMAGE_IMPLEMENTATION #include "stb_image.h" @@ -329,8 +330,13 @@ std::string LLaVAGenerate(std::string llama_param_path, void* llama_model_ptr, s } first_prompt = false; + // Set prompt color + set_print_yellow(); Profiler::getInstance().report_internal(); Profiler::getInstance().reset(); + // Reset color + set_print_reset(); + return output; } diff --git a/llm/src/ops/BMM_F32T.cc b/llm/src/ops/BMM_F32T.cc index d1eb03b5..81cdee4e 100644 --- a/llm/src/ops/BMM_F32T.cc +++ b/llm/src/ops/BMM_F32T.cc @@ -38,11 +38,15 @@ void BMM_F32T::forward(const Matrix3D &a, const Matrix3D &weight, // op.mat_mul_transposed_fastover_column((const struct matmul_params // *)¶ms); // else +#ifdef QM_ARM + op.mat_mul_accelerator_transposed_fastover_column(¶ms); +#else op.mat_mul_transposed(¶ms); // TODO: optimize this // TODO: apply SIMD here for (int i = 0; i < m * n; i++) { params.C.data_ptr[i] *= this->alpha; } +#endif params.A.data_ptr += m * k; params.B.data_ptr += k * n; params.C.data_ptr += m * n; @@ -83,11 +87,18 @@ void BMM_F32T::forward_weight_untransposed(const Matrix3D &a, const Matri for (int i = 0; i < m * n * a.m_dim_x; i++) { params.C.data_ptr[i] = 0; } - +#ifdef QM_ARM + for (int bz = 0; bz < a.m_dim_x; bz++) { + op.mat_mul_accelerator_untransposed_fastover_column(¶ms); + params.A.data_ptr += m * k; + params.B.data_ptr += k * n; + params.C.data_ptr += m * n; + } +#else for (int bz = 0; bz < a.m_dim_x; bz++) { float *data_A = params.A.data_ptr + bz * m * k, *data_B = params.B.data_ptr + bz * k * n, *data_C = params.C.data_ptr + bz * m * n; - for (int i = 0; i < m; i++) + for (int i = 0; i < m; i++) { for (int kk = 0; kk < k; kk++) { float Aikk0 = data_A[i * k + kk]; for (int j = 0; j < n; j++) { @@ -95,7 +106,9 @@ void BMM_F32T::forward_weight_untransposed(const Matrix3D &a, const Matri data_C[i * n + j] += Aikk0 * Bjk0; } } + } } +#endif PROFILE_END(profile_name); } diff --git a/llm/src/ops/linear.cc b/llm/src/ops/linear.cc index b0304c31..863b9971 100644 --- a/llm/src/ops/linear.cc +++ b/llm/src/ops/linear.cc @@ -26,6 +26,14 @@ void linear(Matrix3D &a, Matrix3D &b, Matrix3D &c) { } } +#ifdef QM_ARM +#define MAX_WEIGHT_BUFFER 32000 * 4096 +static float *w_fp32; +void Linear_FP_int4::initialize_weight_memory() { + allocate_aligned_memory(w_fp32, MAX_WEIGHT_BUFFER * sizeof(float)); +} +#endif + void Linear_FP::forward(const Matrix3D &a, Matrix3D &c) { Matrix3D b = this->weight; const int m = a.m_dim_y, n = b.m_dim_y, k = a.m_dim_z, b_size = b.m_dim_x; @@ -207,12 +215,14 @@ void Linear_FP_int4::forward(const Matrix3D &x, Matrix3D &output) else params.bias.data_ptr = this->bias.m_data; #endif -#ifdef QM_ARM - if (params.A.row == 1) { - op.gemv_accelerator_int8_int4_fast_no_offset(¶ms); - } else { +#ifdef USE_ACCELERATE + if (!w_fp32) this->initialize_weight_memory(); + params.B.data_ptr = w_fp32; + if (params.A.row <= 100) { op.mat_mul_accelerator_int8_int4_fast_no_offset(¶ms); - // op.gemm_accelerator_int8_int4_fast_no_offset(¶ms); + } else { + params.alpha = 1.0; + op.cblas_gemm_accelerator_no_offset(¶ms); } #else op.mat_mul_accelerator_int8_int4_fast_no_offset(¶ms); diff --git a/llm/vila b/llm/vila new file mode 100755 index 00000000..c8cd6d60 --- /dev/null +++ b/llm/vila @@ -0,0 +1,7 @@ +# !/bin/bash +echo "=============================================================================================================================" +image_path="$1" +termvisage $image_path -w 75 +echo "=============================================================================================================================" + +./chat VILA_7B INT4 5 $image_path diff --git a/llm/voice_vila b/llm/voice_vila new file mode 100755 index 00000000..2558c19e --- /dev/null +++ b/llm/voice_vila @@ -0,0 +1,7 @@ +# !/bin/bash +echo "=============================================================================================================================" +image_path="$1" +termvisage $image_path -w 75 +echo "=============================================================================================================================" + +./chat -v VILA_7B INT4 5 $image_path