Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bonus Performance Optimization: Cache Blocking Technique #10

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
265 changes: 79 additions & 186 deletions kernels/starter_code/all_techniques.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include <assert.h>
#include <pthread.h>
#include <stdio.h>

#include <cmath>
#include <cstdlib>

Expand All @@ -14,215 +13,109 @@
#ifdef QM_x86
#include <immintrin.h>
#endif

#define CACHE_BLOCK_SIZE 32 // Cache block size for blocking optimization

// Per-thread work descriptor for the W4A8 matmul workers: each thread owns a
// half-open range of output columns of C and shares the read-only problem
// description (matrices, scales, block size) through `params`.
struct w4a8_thread_args {
    int start_j, end_j;                  // column range [start_j, end_j) of C computed by this thread
    const struct matmul_params *params;  // shared matmul description (A, B, C, scales); read-only
};

// Worker thread: computes C[:, start_j:end_j) of the W4A8 (4-bit weight,
// 8-bit activation) matmul. Weights are packed two per byte; activations are
// dynamically quantized int8 with one float scale per 32-element group.
// Returns NULL (pthread entry-point convention).
static void *all_techniques_worker_func(void *args) {
    struct w4a8_thread_args *mat_args = (struct w4a8_thread_args *)args;
    const struct matmul_params *params = mat_args->params;
    const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
    int n = params->C.column, m = params->C.row, k = params->A.column;
    int block_size = params->block_size;
    const int num_block = k / block_size;  // 32-element quantization groups per row (block_size == 32)

#ifdef QM_ARM
    // QM_ARM packing: byte i of a 16-byte vector holds (w_i, w_{i+16}) as
    // (low nibble, high nibble), so one 128-bit load yields one 32-weight group.
    for (int row = 0; row < m; row++) {
        for (int col = mat_args->start_j; col < mat_args->end_j; col++) {
            float32x4_t sumv0 = vdupq_n_f32(0.0f);
            // Pointers into the packed int4 weights and int8 activations.
            const unsigned char *w_start = &params->B.int4_data_ptr[col * k / 2];
            const signed char *a_start = &params->A.int8_data_ptr[row * k];
            // Per-group scales for activations and weights.
            float *s_a = &params->A_scales[row * k / 32];
            float *s_w = &params->scales[col * k / 32];

            const uint8x16_t mask_low4bit = vdupq_n_u8(0xf);
            const int8x16_t offsets = vdupq_n_s8(8);  // zero point: (0,15) -> (-8,7)

            // Process four 32-weight groups (128 weights / activations) per iteration.
            for (int q = 0; q < num_block; q += 4) {
                // Load 4 x 32 4-bit weights (16 bytes each).
                const uint8x16_t w0 = vld1q_u8(w_start);
                const uint8x16_t w1 = vld1q_u8(w_start + 16);
                const uint8x16_t w2 = vld1q_u8(w_start + 32);
                const uint8x16_t w3 = vld1q_u8(w_start + 48);
                w_start += 64;

                // Unpack nibbles (low half = weights 0-15 of the group, high
                // half = weights 16-31) and subtract the zero point.
                const int8x16_t w0_low  = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(w0, mask_low4bit)), offsets);
                const int8x16_t w0_high = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(w0, 4)), offsets);
                const int8x16_t w1_low  = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(w1, mask_low4bit)), offsets);
                const int8x16_t w1_high = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(w1, 4)), offsets);
                const int8x16_t w2_low  = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(w2, mask_low4bit)), offsets);
                const int8x16_t w2_high = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(w2, 4)), offsets);
                const int8x16_t w3_low  = vsubq_s8(vreinterpretq_s8_u8(vandq_u8(w3, mask_low4bit)), offsets);
                const int8x16_t w3_high = vsubq_s8(vreinterpretq_s8_u8(vshrq_n_u8(w3, 4)), offsets);

                // Load 128 int8 activations.
                const int8x16_t a0 = vld1q_s8(a_start);
                const int8x16_t a1 = vld1q_s8(a_start + 16);
                const int8x16_t a2 = vld1q_s8(a_start + 32);
                const int8x16_t a3 = vld1q_s8(a_start + 48);
                const int8x16_t a4 = vld1q_s8(a_start + 64);
                const int8x16_t a5 = vld1q_s8(a_start + 80);
                const int8x16_t a6 = vld1q_s8(a_start + 96);
                const int8x16_t a7 = vld1q_s8(a_start + 112);
                a_start += 128;

                // Per-group int32 dot products via SDOT (each group spans a
                // low-nibble vector and a high-nibble vector).
                int32x4_t int_sum0 = vdotq_s32(vdupq_n_s32(0), w0_low, a0);
                int_sum0 = vdotq_s32(int_sum0, w0_high, a1);
                int32x4_t int_sum1 = vdotq_s32(vdupq_n_s32(0), w1_low, a2);
                int_sum1 = vdotq_s32(int_sum1, w1_high, a3);
                int32x4_t int_sum2 = vdotq_s32(vdupq_n_s32(0), w2_low, a4);
                int_sum2 = vdotq_s32(int_sum2, w2_high, a5);
                int32x4_t int_sum3 = vdotq_s32(vdupq_n_s32(0), w3_low, a6);
                int_sum3 = vdotq_s32(int_sum3, w3_high, a7);

                // Effective scale of each group = activation scale * weight scale.
                float s_0 = *s_a++ * *s_w++;
                float s_1 = *s_a++ * *s_w++;
                float s_2 = *s_a++ * *s_w++;
                float s_3 = *s_a++ * *s_w++;

                sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum0), s_0);
                sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum1), s_1);
                sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum2), s_2);
                sumv0 = vmlaq_n_f32(sumv0, vcvtq_f32_s32(int_sum3), s_3);
            }
            params->C.data_ptr[row * n + col] = vaddvq_f32(sumv0);
        }
    }
#endif
#ifdef QM_x86
    // QM_x86 packing: byte i of a 32-byte vector holds (w_i, w_{i+32}), so two
    // 256-bit loads cover four 32-weight quantization groups (128 weights).
    // Cache blocking: walk C in CACHE_BLOCK_SIZE x CACHE_BLOCK_SIZE tiles so a
    // tile's activation rows and weight columns stay resident in cache.
    for (int row_block = 0; row_block < m; row_block += CACHE_BLOCK_SIZE) {
        for (int col_block = mat_args->start_j; col_block < mat_args->end_j; col_block += CACHE_BLOCK_SIZE) {
            for (int row = row_block; row < row_block + CACHE_BLOCK_SIZE && row < m; row++) {
                // Clamp columns to end_j (not n): clamping to n would let a
                // partial tile spill into columns owned by the next thread.
                for (int col = col_block; col < col_block + CACHE_BLOCK_SIZE && col < mat_args->end_j; col++) {
                    __m256 accumulator = _mm256_setzero_ps();
                    float *s_ptr = &params->scales[col * k / 32];
                    float *sa_ptr = &params->A_scales[row * k / 32];
                    const __m256i *w_start = (__m256i *)&B->int4_data_ptr[col * k / 2];
                    const __m256i *a_start = (__m256i *)&A->int8_data_ptr[row * k];
                    const __m256i lowMask = _mm256_set1_epi8(0xF);
                    const __m256i zero_point = _mm256_set1_epi8(8);  // (0,15) -> (-8,7)
                    const __m256i ones = _mm256_set1_epi16(1);

                    // Four quantization groups (128 4-bit weights) per iteration.
                    for (int q = 0; q < num_block; q += 4) {
                        __m256i raw_w = _mm256_loadu_si256(w_start);
                        __m256i raw_w_next = _mm256_loadu_si256(w_start + 1);

                        // Unpack nibbles and subtract the zero point.
                        __m256i w_0 = _mm256_sub_epi8(_mm256_and_si256(raw_w, lowMask), zero_point);
                        __m256i w_128 = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(raw_w, 4), lowMask), zero_point);
                        __m256i w_0_next = _mm256_sub_epi8(_mm256_and_si256(raw_w_next, lowMask), zero_point);
                        __m256i w_128_next = _mm256_sub_epi8(_mm256_and_si256(_mm256_srli_epi16(raw_w_next, 4), lowMask), zero_point);

                        // _mm256_maddubs_epi16 multiplies unsigned*signed bytes:
                        // feed |w| and transfer w's sign onto the activations.
                        __m256i dot = _mm256_maddubs_epi16(_mm256_sign_epi8(w_0, w_0), _mm256_sign_epi8(a_start[0], w_0));
                        __m256i dot2 = _mm256_maddubs_epi16(_mm256_sign_epi8(w_128, w_128), _mm256_sign_epi8(a_start[1], w_128));
                        __m256i dot3 = _mm256_maddubs_epi16(_mm256_sign_epi8(w_0_next, w_0_next), _mm256_sign_epi8(a_start[2], w_0_next));
                        __m256i dot4 = _mm256_maddubs_epi16(_mm256_sign_epi8(w_128_next, w_128_next), _mm256_sign_epi8(a_start[3], w_128_next));

                        // Widen adjacent 16-bit pair sums to 32-bit, then to float.
                        __m256 intermediate = _mm256_cvtepi32_ps(_mm256_madd_epi16(ones, dot));
                        __m256 intermediate2 = _mm256_cvtepi32_ps(_mm256_madd_epi16(ones, dot2));
                        __m256 intermediate3 = _mm256_cvtepi32_ps(_mm256_madd_epi16(ones, dot3));
                        __m256 intermediate4 = _mm256_cvtepi32_ps(_mm256_madd_epi16(ones, dot4));

                        // Apply per-group effective scale (activation * weight).
                        __m256 v_s = _mm256_set1_ps(s_ptr[0] * sa_ptr[0]);
                        __m256 v_s2 = _mm256_set1_ps(s_ptr[1] * sa_ptr[1]);
                        __m256 v_s3 = _mm256_set1_ps(s_ptr[2] * sa_ptr[2]);
                        __m256 v_s4 = _mm256_set1_ps(s_ptr[3] * sa_ptr[3]);

                        accumulator = _mm256_fmadd_ps(intermediate, v_s, accumulator);
                        accumulator = _mm256_fmadd_ps(intermediate2, v_s2, accumulator);
                        accumulator = _mm256_fmadd_ps(intermediate3, v_s3, accumulator);
                        accumulator = _mm256_fmadd_ps(intermediate4, v_s4, accumulator);

                        s_ptr += 4;
                        sa_ptr += 4;
                        w_start += 2;
                        a_start += 4;
                    }
                    // Horizontal sum of the 8 float lanes into the output cell.
                    float *ptr = (float *)&accumulator;
                    C->data_ptr[row * n + col] = ptr[0] + ptr[1] + ptr[2] + ptr[3] + ptr[4] + ptr[5] + ptr[6] + ptr[7];
                }
            }
        }
    }
#endif
    return NULL;
}

namespace matmul {
void MatmulOperator::mat_mul_all_techniques(struct matmul_params *params) {
int i, j, k;
const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
const int block_size = params->block_size;
float *scale = params->scales, *offset = params->offset;

assert(params->block_size % 32 == 0); // support block size to be multiples of 32
assert(A->row == C->row); // support block size to be multiples of 32
assert(params->block_size == 32); // Ensure block size is 32

quantize_fp32_to_int8(A->data_ptr, A->int8_data_ptr, params->A_scales, A->row * A->column, block_size);
quantize_fp32_to_int8(A->data_ptr, A->int8_data_ptr, params->A_scales, A->row * A->column, params->block_size);

const int num_thread = 8;
pthread_t thread_pool[num_thread];
struct w4a8_thread_args threads_args[num_thread];
assert(params->block_size == 32); // support block size 32 for now
int cols_per_thread = C->column / num_thread;

// TODO: Thread creation

// TODO: Join threads
for (int i = 0; i < num_thread; i++) {
threads_args[i].start_j = i * cols_per_thread;
threads_args[i].end_j = (i == num_thread - 1) ? C->column : (i + 1) * cols_per_thread;
threads_args[i].params = params;
pthread_create(&thread_pool[i], NULL, all_techniques_worker_func, &threads_args[i]);
}
for (int i = 0; i < num_thread; i++) {
pthread_join(thread_pool[i], NULL);
}
};
} // namespace matmul
75 changes: 75 additions & 0 deletions kernels/starter_code/cache_blocking.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#include <assert.h>
#include <pthread.h>
#include <stdio.h>
#include <cmath>
#include <cstdlib>
#include <algorithm>
#include <cstdint>
#include "../matmul.h"
#include "common.h"

namespace matmul {
// Scalar reference W4A8 matmul with cache blocking: C = A (int8) * B (int4),
// computed in 32x32 tiles of C with the reduction dimension processed two
// quantization groups at a time (matching the QM_x86 nibble packing, where a
// packed byte holds w[j] in the low nibble and w[j+32] in the high nibble).
void MatmulOperator::mat_mul_cache_blocking(struct matmul_params *params) {
    const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
    const int block_size = params->block_size;

    // Dynamic activation quantization: one float scale per group.
    quantize_fp32_to_int8(A->data_ptr, A->int8_data_ptr, params->A_scales, A->row * A->column, block_size);

    const int m = C->row;
    const int n = C->column;
    const int k = A->column;

    // Tile sizes for the output rows and columns.
    const int BM = 32;  // rows of C per tile
    const int BN = 32;  // columns of C per tile

    // The inner kernel hard-codes 32-element groups, and each k-step consumes
    // two of them (64 weights) because of the (w[j], w[j+32]) byte packing.
    assert(block_size == 32);
    assert(k % (2 * block_size) == 0);

    for (int i0 = 0; i0 < m; i0 += BM) {
        for (int j0 = 0; j0 < n; j0 += BN) {
            const int imax = std::min(i0 + BM, m);  // clamp partial edge tiles
            const int jmax = std::min(j0 + BN, n);

            // Clear the C tile before accumulating partial sums over k.
            for (int i = i0; i < imax; i++) {
                for (int j = j0; j < jmax; j++) {
                    C->data_ptr[i * n + j] = 0;
                }
            }

            // March along the reduction dimension. The increment lives in the
            // loop header (not inside the #ifdef body) so the loop terminates
            // even when QM_x86 is not defined.
            for (int k0 = 0; k0 < k; k0 += 2 * block_size) {
#ifdef QM_x86
                for (int i = i0; i < imax; i++) {
                    for (int j = j0; j < jmax; j++) {
                        // Per-group scales for the two groups covered by this step.
                        float s_w = params->scales[(j * k + k0) / block_size];
                        float s_a = params->A_scales[(i * k + k0) / block_size];
                        float s_w_2nd = params->scales[(j * k + k0) / block_size + 1];
                        float s_a_2nd = params->A_scales[(i * k + k0) / block_size + 1];

                        // Packed weights (two 4-bit values per byte) and int8 activations.
                        const uint8_t *w_int4 = &B->int4_data_ptr[(j * k + k0) / 2];
                        const signed char *a_int8 = &A->int8_data_ptr[i * k + k0];

                        // Integer dot products for group k0 (low nibbles) and
                        // group k0+32 (high nibbles); scales applied once per group.
                        int intermediate_sum = 0, intermediate_sum_2nd = 0;
                        for (int qj = 0; qj < 32; qj++) {
                            uint8_t packed_int4_0 = w_int4[qj];
                            signed char w_de_0 = (packed_int4_0 & 0x0F) - 8;  // zero point 8
                            signed char w_de_32 = (packed_int4_0 >> 4) - 8;

                            intermediate_sum += a_int8[qj] * w_de_0;
                            intermediate_sum_2nd += a_int8[qj + 32] * w_de_32;
                        }

                        C->data_ptr[i * n + j] += (float)intermediate_sum * s_a * s_w;
                        C->data_ptr[i * n + j] += (float)intermediate_sum_2nd * s_a_2nd * s_w_2nd;
                    }
                }
#endif
            }
        }
    }
}
}  // namespace matmul
2 changes: 1 addition & 1 deletion transformer/Makefile
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# Compiler and flags
CXX = g++
CXXFLAGS = -std=c++11 -pthread -g -O0 -w
CXXFLAGS = -std=c++11 -pthread -Ofast -w
CXXFLAGS += -DIMP=$(IMP)

# Executable and source files
Expand Down
Loading