diff --git a/kernels/starter_code/all_techniques.cc b/kernels/starter_code/all_techniques.cc
index c902e7e..1c9db68 100644
--- a/kernels/starter_code/all_techniques.cc
+++ b/kernels/starter_code/all_techniques.cc
@@ -1,7 +1,6 @@
 #include
 #include
 #include
-
 #include
 #include
@@ -14,215 +13,109 @@
 #ifdef QM_x86
 #include
 #endif
+
+#define CACHE_BLOCK_SIZE 32  // Cache block size for blocking optimization
+
 struct w4a8_thread_args {
     int start_j, end_j;
     const struct matmul_params *params;
 };
+
 static void *all_techniques_worker_func(void *args) {
     struct w4a8_thread_args *mat_args = (struct w4a8_thread_args *)args;
     const struct matmul_params *params = mat_args->params;
     const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
-    int n = params->C.column, m = params->C.row, k = params->A.column, block_size = params->block_size;
-    const int num_block = k / block_size;  // block_size = 32
+    int n = params->C.column, m = params->C.row, k = params->A.column;
+    int block_size = params->block_size;
 
-    for (int row = 0; row < m; row++) {
-        for (int col = mat_args->start_j; col < mat_args->end_j; col++) {
-#ifdef QM_ARM
-            // order of weights with QM_ARM:
-            // origin order: (w0,w1), (w2,w3), (w4,w5), (w6,w7), (w8, w9), ... (w30,w31)
-            // QM_ARM order: (w0,w16),(w1,w17),(w2,w18),(w3,w19),(w4, w20),... (w15,w31)
-            //               |--|
-            //               4 bits
-            //               |------|
-            //               8 bits (byte)
-            //            low|----------------------------------------------------------|high
-            //               0                          128 bit                       127
-            float32x4_t sumv0 = vdupq_n_f32(0.0f);
-            float32x4_t sumv1 = vdupq_n_f32(0.0f);
-            float32x4_t sumv2 = vdupq_n_f32(0.0f);
-            float32x4_t sumv3 = vdupq_n_f32(0.0f);
-            // pointer of the int4 weights
-            const unsigned char *w_start = &params->B.int4_data_ptr[col * k / 2];
-            // pointer of the int8 activation
-            const signed char *a_start = &params->A.int8_data_ptr[row * k];
-            // scale of activation
-            float *s_a = &params->A_scales[row * k / 32];
-            // scale of weight
-            float *s_w = &params->scales[col * k / 32];
-
-            // process four blocks each iteration
-            for (int q = 0; q < num_block; q += 4) {
-                // load 32x4bit (16 bytes) weight
-                const uint8x16_t w0 = vld1q_u8(w_start);       // 32 4bit weight
-                const uint8x16_t w1 = vld1q_u8(w_start + 16);  // 32 4bit weight
-                const uint8x16_t w2 = vld1q_u8(w_start + 32);  // 32 4bit weight
-                const uint8x16_t w3 = vld1q_u8(w_start + 48);  // 32 4bit weight
-                w_start += 64;
-
-                // TODO: decode each uint8x16_t weight vector into the lower and upper half of the weights as int8x16_t
-                // Hint:
-                // (1) use `vandq_u8` with the mask_low4bit to get the lower half
-                // (2) use `vshrq_n_u8` to right shift 4 bits and get the upper half
-                // (3) use `vreinterpretq_s8_u8` to interpret the vector as int8
-                // lowbit mask
-                const uint8x16_t mask_low4bit = vdupq_n_u8(0xf);
-
-                // TODO: apply zero_point to weights and convert the range from (0, 15) to (-8, 7)
-                // Hint: using `vsubq_s8` to the lower-half and upper-half vectors of weights
-                const int8x16_t offsets = vdupq_n_s8(8);
-
-                // load 128 8-bit activation
-                const int8x16_t a0 = vld1q_s8(a_start);
-                const int8x16_t a1 = vld1q_s8(a_start + 16);
-                const int8x16_t a2 = vld1q_s8(a_start + 32);
-                const int8x16_t a3 = vld1q_s8(a_start + 48);
-                const int8x16_t a4 = vld1q_s8(a_start + 64);
-                const int8x16_t a5 = vld1q_s8(a_start + 80);
-                const int8x16_t a6 = vld1q_s8(a_start + 96);
-                const int8x16_t a7 = vld1q_s8(a_start + 112);
-                a_start += 128;
-
-                // TODO: perform dot product and store the result into the intermediate sum, int_sum0
-                // Hint: use `vdotq_s32` and store the sum for each block in int_sum{0-3}
-                int32x4_t int_sum0, int_sum1, int_sum2, int_sum3;
-
-                float s_0 = *s_a++ * *s_w++;
-                float s_1 = *s_a++ * *s_w++;
-                float s_2 = *s_a++ * *s_w++;
-                float s_3 = *s_a++ * *s_w++;
-
-                sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0);
-                sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum1), s_1);
-                sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum2), s_2);
-                sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum3), s_3);
-            }
-            params->C.data_ptr[row * n + col] = vaddvq_f32(sumv0);
-#endif
 #ifdef QM_x86
-            // order of weights with QM_x86:
-            // origin order: (w0,w1), (w2,w3), (w4,w5), (w6,w7), (w8, w9), ... (w62,w63)
-            // QM_ARM order: (w0,w32),(w1,w33),(w2,w34),(w3,w35),(w4, w36),... (w31,w63)
-            //               |--|
-            //               4 bits
-            //               |------|
-            //               8 bits (byte)
-            //            low|----------------------------------------------------------|high
-            //               0                           256 bit
-            __m256 accumulator = _mm256_setzero_ps();
-            float *s_ptr = &params->scales[col * k / 32];
-            float *sa_ptr = &params->A_scales[row * k / 32];
-            const __m256i *w_start = (__m256i *)&B->int4_data_ptr[col * k / 2];
-            const __m256i *a_start = (__m256i *)&A->int8_data_ptr[row * k];
-            const int num_block = k / block_size;
-            // Compute four blocks = 128 4-bit weights in each iteration
-            for (int q = 0; q < num_block; q += 4) {
-                // lowbit mask
-                const __m256i lowMask = _mm256_set1_epi8(0xF);
-
-                // TODO: Unpack 128 4-bit (two __mm256i) weights into 128 8-bit (four __mm256i)
-                // (1) load 256 bit from w_strat with _mm256_loadu_si256
-                // (2) use _mm256_and_si256 and lowMask to extract the lower half of wegihts
-                // (3) use _mm256_srli_epi16 and _mm256_and_si256 with lowMask to extract the upper half of weights
-                __m256i raw_w = _mm256_loadu_si256(w_start);
-                __m256i raw_w_next = _mm256_loadu_si256(w_start + 1);
-
-                // TODO: apply zero_point to weights and convert the range from (0, 15) to (-8, 7)
-                // Hint: using `_mm256_sub_epi8` to the lower-half and upper-half vectors of weights
-                // Note: For the first two blocks, store the lower half and upper half of weights into `w_0` and
-                // `w_128`, respectively For the last two blocks store the lower half and upper half of weights into
-                // `w_0_next` and `w_128_next`, respectively
-                const __m256i zero_point = _mm256_set1_epi8(8);
-                __m256i w_0, w_128, w_0_next, w_128_next;
-
-                // Perform int8 dot product with _mm256_maddubs_epi16
-                /* Syntax of _mm256_maddubs_epi16:
-                   __m256i _mm256_maddubs_epi16(__m256i s1, __m256i s2): Multiplies vertically each unsigned byte of
-                   source vector s1 with the corresponding signed byte of source vector s2, producing intermediate,
-                   signed 16-bit integers. Each adjacent pair of signed words is added, and the saturated result is
-                   packed to the destination vector.
-                */
-                // To utilize _mm256_maddubs_epi16 which only takes unsigned s1, we need to:
-                // (1) Get the absolute values of weights (for both lower and upper halves)
-                // (2) Change the sign of activation (a0-a31 and a32-a63) depending on the sign of corresponding weights
-                // (stored as another variable) (3) Perform dot product with _mm256_maddubs_epi16 and store the lower
-                // and upper halves sum in `dot` and `dot2`
-                __m256i dot, dot2, dot3, dot4;
-                // Get absolute values of x vectors
-                const __m256i ax = _mm256_sign_epi8(w_0, w_0);
-                const __m256i ax_next = _mm256_sign_epi8(w_0_next, w_0_next);
-                const __m256i ax2 = _mm256_sign_epi8(w_128, w_128);
-                const __m256i ax2_next = _mm256_sign_epi8(w_128_next, w_128_next);
-                // Load activation
-                __m256i activation = a_start[0];
-                __m256i activation2 = a_start[1];
-                __m256i activation_next = a_start[2];
-                __m256i activation2_next = a_start[3];
-                // Sign the values of the y vectors
-                const __m256i sy = _mm256_sign_epi8(activation, w_0);
-                const __m256i sy_next = _mm256_sign_epi8(activation_next, w_0_next);
-                const __m256i sy2 = _mm256_sign_epi8(activation2, w_128);
-                const __m256i sy2_next = _mm256_sign_epi8(activation2_next, w_128_next);
-
-                // TODO: Perform int8 dot product with `_mm256_maddubs_epi16`
-                // Hint: use `_mm256_maddubs_epi16` to complete the following computation
-                // dot = ax * sy
-                // dot2 = ax2 * sy2
-                // dot3 = ax_next * sy_next
-                // dot4 = ax2_next * sy2_next
-
-                // Convert int32 vectors to floating point vectors
-                const __m256i ones = _mm256_set1_epi16(1);
-                const __m256i summed_pairs = _mm256_madd_epi16(ones, dot);
-                const __m256i summed_pairs2 = _mm256_madd_epi16(ones, dot2);
-                const __m256i summed_pairs3 = _mm256_madd_epi16(ones, dot3);
-                const __m256i summed_pairs4 = _mm256_madd_epi16(ones, dot4);
-                __m256 intermediate = _mm256_cvtepi32_ps(summed_pairs);
-                __m256 intermediate2 = _mm256_cvtepi32_ps(summed_pairs2);
-                __m256 intermediate3 = _mm256_cvtepi32_ps(summed_pairs3);
-                __m256 intermediate4 = _mm256_cvtepi32_ps(summed_pairs4);
-
-                // Create vectors for scales and apply them to intermediate results
-                __m256 v_s = _mm256_set1_ps(s_ptr[0] * sa_ptr[0]);
-                __m256 v_s2 = _mm256_set1_ps(s_ptr[1] * sa_ptr[1]);
-                __m256 v_s3 = _mm256_set1_ps(s_ptr[2] * sa_ptr[2]);
-                __m256 v_s4 = _mm256_set1_ps(s_ptr[3] * sa_ptr[3]);
-                accumulator = _mm256_fmadd_ps(intermediate, v_s, accumulator);
-                accumulator = _mm256_fmadd_ps(intermediate2, v_s2, accumulator);
-                accumulator = _mm256_fmadd_ps(intermediate3, v_s3, accumulator);
-                accumulator = _mm256_fmadd_ps(intermediate4, v_s4, accumulator);
-                s_ptr += 4;
-                sa_ptr += 4;
-                w_start += 2;
-                a_start += 4;
+    // Cache blocking technique integrated
+    for (int row_block = 0; row_block < m; row_block += CACHE_BLOCK_SIZE) {
+        for (int col_block = mat_args->start_j; col_block < mat_args->end_j; col_block += CACHE_BLOCK_SIZE) {
+            for (int row = row_block; row < row_block + CACHE_BLOCK_SIZE && row < m; row++) {
+                for (int col = col_block; col < col_block + CACHE_BLOCK_SIZE && col < n; col++) {
+                    __m256 accumulator = _mm256_setzero_ps();
+                    float *s_ptr = &params->scales[col * k / 32];
+                    float *sa_ptr = &params->A_scales[row * k / 32];
+                    const __m256i *w_start = (__m256i *)&B->int4_data_ptr[col * k / 2];
+                    const __m256i *a_start = (__m256i *)&A->int8_data_ptr[row * k];
+                    const int num_block = k / block_size;
+
+                    for (int q = 0; q < num_block; q += 4) {
+                        const __m256i lowMask = _mm256_set1_epi8(0xF);
+                        __m256i raw_w = _mm256_loadu_si256(w_start);
+                        __m256i raw_w_next = _mm256_loadu_si256(w_start + 1);
+
+                        __m256i w_low = _mm256_and_si256(raw_w, lowMask);
+                        __m256i w_high = _mm256_srli_epi16(raw_w, 4);
+                        w_high = _mm256_and_si256(w_high, lowMask);
+
+                        __m256i w_low_next = _mm256_and_si256(raw_w_next, lowMask);
+                        __m256i w_high_next = _mm256_srli_epi16(raw_w_next, 4);
+                        w_high_next = _mm256_and_si256(w_high_next, lowMask);
+
+                        const __m256i zero_point = _mm256_set1_epi8(8);
+                        __m256i w_0 = _mm256_sub_epi8(w_low, zero_point);
+                        __m256i w_128 = _mm256_sub_epi8(w_high, zero_point);
+                        __m256i w_0_next = _mm256_sub_epi8(w_low_next, zero_point);
+                        __m256i w_128_next = _mm256_sub_epi8(w_high_next, zero_point);
+
+                        __m256i dot = _mm256_maddubs_epi16(_mm256_sign_epi8(w_0, w_0), _mm256_sign_epi8(a_start[0], w_0));
+                        __m256i dot2 = _mm256_maddubs_epi16(_mm256_sign_epi8(w_128, w_128), _mm256_sign_epi8(a_start[1], w_128));
+                        __m256i dot3 = _mm256_maddubs_epi16(_mm256_sign_epi8(w_0_next, w_0_next), _mm256_sign_epi8(a_start[2], w_0_next));
+                        __m256i dot4 = _mm256_maddubs_epi16(_mm256_sign_epi8(w_128_next, w_128_next), _mm256_sign_epi8(a_start[3], w_128_next));
+
+                        const __m256i ones = _mm256_set1_epi16(1);
+                        __m256 intermediate = _mm256_cvtepi32_ps(_mm256_madd_epi16(ones, dot));
+                        __m256 intermediate2 = _mm256_cvtepi32_ps(_mm256_madd_epi16(ones, dot2));
+                        __m256 intermediate3 = _mm256_cvtepi32_ps(_mm256_madd_epi16(ones, dot3));
+                        __m256 intermediate4 = _mm256_cvtepi32_ps(_mm256_madd_epi16(ones, dot4));
+
+                        __m256 v_s = _mm256_set1_ps(s_ptr[0] * sa_ptr[0]);
+                        __m256 v_s2 = _mm256_set1_ps(s_ptr[1] * sa_ptr[1]);
+                        __m256 v_s3 = _mm256_set1_ps(s_ptr[2] * sa_ptr[2]);
+                        __m256 v_s4 = _mm256_set1_ps(s_ptr[3] * sa_ptr[3]);
+
+                        accumulator = _mm256_fmadd_ps(intermediate, v_s, accumulator);
+                        accumulator = _mm256_fmadd_ps(intermediate2, v_s2, accumulator);
+                        accumulator = _mm256_fmadd_ps(intermediate3, v_s3, accumulator);
+                        accumulator = _mm256_fmadd_ps(intermediate4, v_s4, accumulator);
+
+                        s_ptr += 4;
+                        sa_ptr += 4;
+                        w_start += 2;
+                        a_start += 4;
+                    }
+                    float *ptr = (float *)&accumulator;
+                    C->data_ptr[row * n + col] = ptr[0] + ptr[1] + ptr[2] + ptr[3] + ptr[4] + ptr[5] + ptr[6] + ptr[7];
+                }
             }
-            float *ptr = (float *)&accumulator;
-            C->data_ptr[row * n + col] = ptr[0] + ptr[1] + ptr[2] + ptr[3] + ptr[4] + ptr[5] + ptr[6] + ptr[7];
-#endif
         }
     }
-
+#endif
     return NULL;
 }
 
 namespace matmul {
 void MatmulOperator::mat_mul_all_techniques(struct matmul_params *params) {
-    int i, j, k;
     const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
-    const int block_size = params->block_size;
-    float *scale = params->scales, *offset = params->offset;
-
-    assert(params->block_size % 32 == 0);  // support block size to be multiples of 32
-    assert(A->row == C->row);              // support block size to be multiples of 32
+    assert(params->block_size == 32);  // Ensure block size is 32
 
-    quantize_fp32_to_int8(A->data_ptr, A->int8_data_ptr, params->A_scales, A->row * A->column, block_size);
+    quantize_fp32_to_int8(A->data_ptr, A->int8_data_ptr, params->A_scales, A->row * A->column, params->block_size);
 
     const int num_thread = 8;
     pthread_t thread_pool[num_thread];
     struct w4a8_thread_args threads_args[num_thread];
-    assert(params->block_size == 32);  // support block size 32 for now
+    int cols_per_thread = C->column / num_thread;
 
-    // TODO: Thread creation
-
-    // TODO: Join threads
+    for (int i = 0; i < num_thread; i++) {
+        threads_args[i].start_j = i * cols_per_thread;
+        threads_args[i].end_j = (i == num_thread - 1) ? C->column : (i + 1) * cols_per_thread;
+        threads_args[i].params = params;
+        pthread_create(&thread_pool[i], NULL, all_techniques_worker_func, &threads_args[i]);
+    }
+    for (int i = 0; i < num_thread; i++) {
+        pthread_join(thread_pool[i], NULL);
+    }
 };
 }  // namespace matmul
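
For reference (not part of the patch): `_mm256_maddubs_epi16` multiplies an unsigned byte vector by a signed one, so the kernel above feeds it the absolute values of the weights and activations whose signs have been copied from the weights; since |w| * sign(a, w) == w * a for every byte, the products are unchanged, and `_mm256_madd_epi16` with a vector of ones then reduces adjacent 16-bit pairs to int32 partial sums. A minimal sketch of this pattern for one 32-byte group follows; the helper name `int8_dot_block` is assumed here purely for illustration.

    #include <immintrin.h>

    // Illustrative only: one 32-byte signed-int8 dot-product step, mirroring the
    // sign trick used in all_techniques.cc above.
    static inline __m256i int8_dot_block(__m256i w, __m256i a) {
        const __m256i abs_w = _mm256_sign_epi8(w, w);                // |w|, safe as the unsigned operand
        const __m256i sgn_a = _mm256_sign_epi8(a, w);                // a with the sign of w applied
        const __m256i prod16 = _mm256_maddubs_epi16(abs_w, sgn_a);   // adjacent pairs of w*a, as int16
        return _mm256_madd_epi16(_mm256_set1_epi16(1), prod16);      // 8 int32 partial sums
    }
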
diff --git a/kernels/starter_code/cache_blocking.cc b/kernels/starter_code/cache_blocking.cc
new file mode 100644
index 0000000..4e81b20
--- /dev/null
+++ b/kernels/starter_code/cache_blocking.cc
@@ -0,0 +1,75 @@
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "../matmul.h"
+#include "common.h"
+
+namespace matmul {
+void MatmulOperator::mat_mul_cache_blocking(struct matmul_params *params) {
+    const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
+    const int block_size = params->block_size;
+
+    quantize_fp32_to_int8(A->data_ptr, A->int8_data_ptr, params->A_scales, A->row * A->column, block_size);
+
+    const int m = C->row;
+    const int n = C->column;
+    const int k = A->column;
+
+    // Defining block sizes for each dimension
+    const int BM = 32;  // Block size for rows
+    const int BN = 32;  // Block size for columns
+    const int BK = 32;  // Block size for reduction dimension
+
+    // Iterate over blocks
+    for (int i0 = 0; i0 < m; i0 += BM) {
+        for (int j0 = 0; j0 < n; j0 += BN) {
+            // Initialize accumulator block
+            const int imax = std::min(i0 + BM, m);
+            const int jmax = std::min(j0 + BN, n);
+
+            for (int i = i0; i < imax; i++) {
+                for (int j = j0; j < jmax; j++) {
+                    C->data_ptr[i * n + j] = 0;
+                }
+            }
+
+            // Process blocks along k dimension
+            for (int k0 = 0; k0 < k;) {
+#ifdef QM_x86
+                for (int i = i0; i < imax; i++) {
+                    for (int j = j0; j < jmax; j++) {
+                        // Get scales for the current block
+                        float s_w = params->scales[(j * k + k0) / block_size];
+                        float s_a = params->A_scales[(i * k + k0) / block_size];
+                        float s_w_2nd = params->scales[(j * k + k0) / block_size + 1];
+                        float s_a_2nd = params->A_scales[(i * k + k0) / block_size + 1];
+
+                        // Process the block
+                        uint8_t *w_int4 = &B->int4_data_ptr[(j * k + k0) / 2];
+                        const signed char *a_int8 = &A->int8_data_ptr[i * k + k0];
+
+                        int intermediate_sum = 0, intermediate_sum_2nd = 0;
+                        for (int qj = 0; qj < 32; qj++) {
+                            uint8_t packed_int4_0 = w_int4[qj];
+                            signed char w_de_0 = (packed_int4_0 & 0x0F) - 8;
+                            signed char w_de_32 = (packed_int4_0 >> 4) - 8;
+
+                            intermediate_sum += a_int8[qj] * w_de_0;
+                            intermediate_sum_2nd += a_int8[qj + 32] * w_de_32;
+                        }
+
+                        C->data_ptr[i * n + j] += (float)intermediate_sum * s_a * s_w;
+                        C->data_ptr[i * n + j] += (float)intermediate_sum_2nd * s_a_2nd * s_w_2nd;
+                    }
+                }
+                k0 += block_size * 2;
+#endif
+            }
+        }
+    }
+}
+}  // namespace matmul
\ No newline at end of file
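
For reference (not part of the patch): cache_blocking.cc relies on the same packed layout the SIMD kernel decodes. Within each 32-byte group, byte j carries weight j of the first 32-element quantization block in its low nibble and weight j of the second block in its high nibble, with a zero point of 8 and one (s_w, s_a) scale pair per block. A scalar sketch of the per-group math follows; the helper name `w4a8_group_dot_scalar` is assumed purely for illustration.

    #include <cstdint>

    // Illustrative only: contribution of one 32-byte weight group (two quantization
    // blocks) to an output element, matching the inner loop of cache_blocking.cc.
    static float w4a8_group_dot_scalar(const uint8_t *w_packed,   // 32 bytes = 64 4-bit weights
                                       const int8_t *a_int8,      // 64 int8 activations
                                       float s_w0, float s_a0,    // scales of the first block
                                       float s_w1, float s_a1) {  // scales of the second block
        int sum0 = 0, sum1 = 0;
        for (int j = 0; j < 32; j++) {
            int w_lo = (w_packed[j] & 0x0F) - 8;  // weight j of the first block
            int w_hi = (w_packed[j] >> 4) - 8;    // weight j of the second block
            sum0 += a_int8[j] * w_lo;
            sum1 += a_int8[j + 32] * w_hi;
        }
        return (float)sum0 * s_w0 * s_a0 + (float)sum1 * s_w1 * s_a1;
    }
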
"all_techniques" "cache_blocking") +values=("0" "1" "2" "3" "4" "5" "6") # If a implementation is provided to the script, map it to the corresponding argument if [ "$#" -eq 1 ]; then diff --git a/transformer/include/ops/linear.h b/transformer/include/ops/linear.h index 3d0ae5d..11daa8b 100644 --- a/transformer/include/ops/linear.h +++ b/transformer/include/ops/linear.h @@ -56,6 +56,8 @@ class Linear_FP_int4 { std::string profile_name = "multithreading_loop_unrolling"; #elif IMP == 5 std::string profile_name = "all_techniques"; +#elif IMP == 6 + std::string profile_name = "cache_blocking"; #else std::string profile_name = "Unkown"; #endif