From fdc001922898df80fd07901f15a07a9e304ed705 Mon Sep 17 00:00:00 2001 From: Wei-Ming Chen Date: Sun, 30 Apr 2023 23:56:54 -0400 Subject: [PATCH] Basic implementation for avx (#69) * add baseline implementation * support onednn * kernel timing * performance * int8 onednn * utils * minor * fix * mnior * avx imp for int8 gemm * add file * fix * cuda compiler flags * compilation for int8 * minor * minor * minor * 2x2 unroll * omp imp * unroll 32 elements * min/max params * bias support * minor * minor * fix * bf32 fp32 ops * bmm * fix * rounding * fix --- experimental/matmul_optimization/src/Makefile | 96 ++ .../matmul_optimization/src/README.md | 3 + .../matmul_optimization/src/benchmark/main.cc | 241 +++++ .../src/benchmark/main_int8.cc | 114 +++ .../matmul_optimization/src/lib/matmul.cu | 94 ++ .../matmul_optimization/src/lib/matmul.h | 88 ++ .../src/lib/matmul_avx_int8.cc | 898 ++++++++++++++++++ .../matmul_optimization/src/lib/matmul_imp.cc | 446 +++++++++ .../src/lib/matmul_int8.cc | 31 + .../src/lib/matmul_onednn.cc | 145 +++ .../matmul_optimization/src/lib/utils.cc | 12 + 11 files changed, 2168 insertions(+) create mode 100644 experimental/matmul_optimization/src/Makefile create mode 100644 experimental/matmul_optimization/src/README.md create mode 100644 experimental/matmul_optimization/src/benchmark/main.cc create mode 100644 experimental/matmul_optimization/src/benchmark/main_int8.cc create mode 100644 experimental/matmul_optimization/src/lib/matmul.cu create mode 100644 experimental/matmul_optimization/src/lib/matmul.h create mode 100644 experimental/matmul_optimization/src/lib/matmul_avx_int8.cc create mode 100644 experimental/matmul_optimization/src/lib/matmul_imp.cc create mode 100644 experimental/matmul_optimization/src/lib/matmul_int8.cc create mode 100644 experimental/matmul_optimization/src/lib/matmul_onednn.cc create mode 100644 experimental/matmul_optimization/src/lib/utils.cc diff --git a/experimental/matmul_optimization/src/Makefile b/experimental/matmul_optimization/src/Makefile new file mode 100644 index 00000000..9477b84b --- /dev/null +++ b/experimental/matmul_optimization/src/Makefile @@ -0,0 +1,96 @@ +# Check operating system +OS := $(shell uname) + +# OneDNN availability +ONEDNN_AVAILABLE = +ifeq ($(OS), Darwin) # macOS + $(info Detected macOS) + ONEDNN_AVAILABLE := $(shell otool -L /usr/local/lib/libdnnl* 2> /dev/null) +else ifeq ($(OS), Linux) # Ubuntu or other Linux distributions + $(info Detected Linux) + ONEDNN_AVAILABLE_CHK := $(shell pkg-config --exists dnnl; echo $$?) 
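+  # pkg-config exits with status 0 when the dnnl package is installed, so the
+  # inner ONEDNN_AVAILABLE probe below only runs when oneDNN is actually present.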
+ ifeq ($(ONEDNN_AVAILABLE_CHK), 0) + ONEDNN_AVAILABLE := $(shell pkg-config --exists onednn 2> /dev/null) # TODO: check this in Linux env + endif +else + $(error Unsupported operating system) +endif + +# Check if CUDA is available +CUDA_AVAILABLE := $(shell command -v /usr/local/cuda/bin/nvcc 2> /dev/null) + +CC_FLAGS = -O3 -std=c++11 #-g +#CC_FLAGS = -O3 -std=c++11 -Xclang -fopenmp -g +# Compiler and flags +ifdef CUDA_AVAILABLE + CC = /usr/local/cuda/bin/nvcc + CC_FLAGS += -DCUDA_ENABLE + $(info CUDA is available) +else + CC = g++ + CC_FLAGS += -mavx2 -mfma +endif +ifdef ONEDNN_AVAILABLE + CC_FLAGS += -DONEDNN_ENABLE + $(info ONEDNN is available) +endif + +# Include directories +# INCLUDE_DIRS = -I./ -I/usr/local/opt/libomp/include +INCLUDE_DIRS = -I./ + +# Library directories +LIBRARY_DIRS = -L/usr/local/cuda/lib64 + +# Library flag +LDFLAGS = +ifdef ONEDNN_AVAILABLE +LDFLAGS += -ldnnl +endif + +# TODO: openmp flag +OMP_FLAGS = -L/usr/local/opt/libomp/lib/ -lomp +# LDFLAGS += $(OMP_FLAGS + +# Files +TARGET = benchmark_run +CUDA_SRCS = lib/matmul.cu +CPP_SRCS = benchmark/main.cc lib/matmul_imp.cc lib/utils.cc lib/matmul_int8.cc lib/matmul_avx_int8.cc +ONEDNN_SRCS = lib/matmul_onednn.cc + +# Objects +OBJS = $(CPP_SRCS:.cc=.o) +INT8_OBJS = $(INT8_CPP_SRCS:.cc=.o) +ifdef CUDA_AVAILABLE +OBJS += $(CUDA_SRCS:.cu=.o) +endif +ifdef ONEDNN_AVAILABLE +OBJS += $(ONEDNN_SRCS:.cc=.o) +INT8_OBJS += $(ONEDNN_SRCS:.cc=.o) +endif + + +# $(info ONEDNN_AVAILABLE: $(ONEDNN_AVAILABLE)) +$(info CC_FLAGS: $(CC_FLAGS)) + + +# Targets +all: $(TARGET) + +$(TARGET): $(OBJS) + $(CC) $(CC_FLAGS) $(INCLUDE_DIRS) $(LDFLAGS) -o $(TARGET) $(OBJS) + +%.o: %.cu + $(CC) $(CC_FLAGS) $(INCLUDE_DIRS) $(LDFLAGS) -c $< -o $@ + +ifdef CUDA_AVAILABLE +%.o: %.cc + $(CC) $(CC_FLAGS) $(INCLUDE_DIRS) $(LDFLAGS) -x cu -c $< -o $@ +else +%.o: %.cc + $(CC) $(CC_FLAGS) $(INCLUDE_DIRS) $(LDFLAGS) -c $< -o $@ + #$(CC) $(CC_FLAGS) $(INCLUDE_DIRS) $(LDFLAGS) -c $< -o $@ $(OMP_FLAGS) +endif + +clean: + rm -f $(TARGET) $(OBJS) diff --git a/experimental/matmul_optimization/src/README.md b/experimental/matmul_optimization/src/README.md new file mode 100644 index 00000000..929ba4b8 --- /dev/null +++ b/experimental/matmul_optimization/src/README.md @@ -0,0 +1,3 @@ +# Build onednn (enable openmp on mac) + +cmake .. 
-DOpenMP_C_FLAGS="-Xclang -fopenmp -I/usr/local/opt/libomp/include" -DOpenMP_C_LIB_NAMES="libomp" -DDNNL_CPU_RUNTIME=OMP -DOpenMP_CXX_FLAGS="-Xclang -fopenmp -I/usr/local/opt/libomp/include" -DOpenMP_CXX_LIB_NAMES="libomp" -DOpenMP_libomp_LIBRARY=/usr/local/opt/libomp/lib/libomp.dylib -DCMAKE_SHARED_LINKER_FLAGS="-L/usr/local/opt/libomp/lib/ -lomp -Wl,-rpath,/usr/local/opt/libomp/lib/" diff --git a/experimental/matmul_optimization/src/benchmark/main.cc b/experimental/matmul_optimization/src/benchmark/main.cc new file mode 100644 index 00000000..5b91692e --- /dev/null +++ b/experimental/matmul_optimization/src/benchmark/main.cc @@ -0,0 +1,241 @@ +#include +#include + +#include +#include + +#include "lib/matmul.h" + +#define BLK_SIZE 16 +#define MAX_PRECISION_ERROR 0.01 + +#define M 1024 +#define N 1024 +#define K 1024 +#define A_ROW M +#define A_COLUMN K +#define B_ROW K +#define B_COLUMN N +#define C_ROW M +#define C_COLUMN N +#define NUM_THREAD 16 + +float MAT_A[A_ROW * A_COLUMN]; +float MAT_B[B_ROW * B_COLUMN]; +float transpose_B[B_ROW * B_COLUMN]; +float native_C[C_ROW * C_COLUMN]; +float output_C[C_ROW * C_COLUMN]; + +int8_t MAT_A_s8[A_ROW * A_COLUMN]; +int8_t MAT_B_s8[B_ROW * B_COLUMN]; +int32_t bias_s32[C_COLUMN]; +int8_t transpose_B_s8[B_ROW * B_COLUMN]; +int8_t native_C_s8[C_ROW * C_COLUMN]; +int8_t output_C_s8[C_ROW * C_COLUMN]; + +bool check_identical(float matA[], float matB[], int size) { + for (int i = 0; i < size; i++) { + if (abs((matA[i] - matB[i]) / (matA[i])) > MAX_PRECISION_ERROR) { + printf("%d: %f, %f", i, matA[i], matB[i]); + return false; + } + } + return true; +} + +bool check_identical(int8_t matA[], int8_t matB[], int size) { + for (int i = 0; i < size; i++) { + if (matA[i] != matB[i]) { + printf("%d: %d, %d", i, matA[i], matB[i]); + return false; + } + } + return true; +} + +template +void dump_integer_array(T matA[], int size) { + for (int i = 0; i < size; i++) { + printf("%d,", matA[i]); + } + printf("\n"); +} + +void initialize_matrix(float A[], int size) { + for (int i = 0; i < size; i++) { + A[i] = (float)(rand()) / (float)(RAND_MAX); + } +} + +void initialize_matrix(int8_t A[], int size) { + for (int i = 0; i < size; i++) { + // A[i] = (rand() % 2) - 1; + A[i] = (rand() % 2); + } +} + +void initialize_matrix(int32_t A[], int size) { + for (int i = 0; i < size; i++) { + // A[i] = (rand() % 2) - 1; + A[i] = (rand() % 2); + } +} + +using namespace matmul; + +int main() { + // initialize + initialize_matrix(MAT_A, A_ROW * A_COLUMN); + initialize_matrix(MAT_B, B_ROW * B_COLUMN); + initialize_matrix(native_C, C_ROW * C_COLUMN); + + initialize_matrix(MAT_A_s8, A_ROW * A_COLUMN); + initialize_matrix(MAT_B_s8, B_ROW * B_COLUMN); + initialize_matrix(native_C_s8, C_ROW * C_COLUMN); + // initialize_matrix(bias_s32, C_ROW * C_COLUMN); + + MatmulOperator matmul_op = MatmulOperator(); + + struct matmul_params params, params_int8; + params.A.row = A_ROW; + params.A.column = A_COLUMN; + params.A.data_ptr = MAT_A; + params.B.row = B_ROW; + params.B.column = B_COLUMN; + params.B.data_ptr = MAT_B; + params.C.row = C_ROW; + params.C.column = C_COLUMN; + params.opt_params.blk_size = BLK_SIZE; + params.opt_params.num_thread = NUM_THREAD; + + // int8 + params_int8.A.row = A_ROW; + params_int8.A.column = A_COLUMN; + params_int8.A.int8_data_ptr = MAT_A_s8; + params_int8.A.qparams.scale = 1.0; + params_int8.A.qparams.zero_point = 0; + params_int8.B.row = B_ROW; + params_int8.B.column = B_COLUMN; + params_int8.B.int8_data_ptr = MAT_B_s8; + params_int8.B.qparams.scale = 1.0; + 
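+    // All scales are 1.0 and all zero points are 0, so the requantization applied by
+    // the AVX int8 kernels (scale the int32 accumulator by A_sc * B_sc / C_sc,
+    // subtract C_zp, then clamp to [q_min, q_max]) reduces to clamping the raw
+    // accumulator, which is why the comparisons against INT8_BASELINE below can
+    // expect bit-identical int8 outputs.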
params_int8.B.qparams.zero_point = 0; + params_int8.C.row = C_ROW; + params_int8.C.column = C_COLUMN; + params_int8.C.int8_data_ptr = native_C_s8; + params_int8.C.qparams.scale = 1.0; + params_int8.C.qparams.q_max = 127; + params_int8.C.qparams.q_min = -128; + params_int8.C.qparams.zero_point = 0; + params_int8.opt_params.blk_size = BLK_SIZE; + params_int8.opt_params.num_thread = NUM_THREAD; + params_int8.bias.row = 1; + params_int8.bias.column = C_COLUMN; + params_int8.bias.int32_data_ptr = bias_s32; + + // Baseline + params.C.data_ptr = native_C; + matmul_op.evaluate(MatmulOperator::NAIVE, ¶ms); + + params.C.data_ptr = output_C; + // unrolling + matmul_op.evaluate(MatmulOperator::UNROLL, ¶ms); + if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_unrolling\n"); + + // reordering + matmul_op.evaluate(MatmulOperator::REORDER, ¶ms); + if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_reordering\n"); + + // tiling + matmul_op.evaluate(MatmulOperator::TILING, ¶ms); + if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_tiling\n"); + + // multithreading + matmul_op.evaluate(MatmulOperator::MULTITHREAD, ¶ms); + if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_multithreading\n"); + + // transpose + matmul_op.evaluate(MatmulOperator::TRANSPOSE, ¶ms); + if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_transpose\n"); + + // transpose + simd + initialize_matrix(output_C, C_ROW * C_COLUMN); + matmul_op.evaluate(MatmulOperator::TRANSPOSE_SIMD, ¶ms); + if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_transpose_simd\n"); + +// cuda +#ifdef CUDA_ENABLE + matmul_op.evaluate(MatmulOperator::CUDA, ¶ms); + if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_cuda\n"); +#endif + +// ONEDNN +#ifdef ONEDNN_ENABLE + initialize_matrix(output_C, C_ROW * C_COLUMN); + matmul_op.evaluate(MatmulOperator::ONEDNN_FP32, ¶ms); + if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("\nincorrect output of mat_mul_onedenn\n"); +#endif + + // For fast, we need to transpose B first + for (int i = 0; i < B_COLUMN; i++) + for (int j = 0; j < B_ROW; j++) transpose_B[i * B_ROW + j] = MAT_B[j * B_COLUMN + i]; + params.B.column = B_ROW; + params.B.row = B_COLUMN; + params.B.data_ptr = transpose_B; + params.opt_params.blk_size = BLK_SIZE; + params.opt_params.num_thread = NUM_THREAD; + + // fast + initialize_matrix(output_C, C_ROW * C_COLUMN); + matmul_op.evaluate(MatmulOperator::FAST, ¶ms); + if (!check_identical(native_C, output_C, C_ROW * C_COLUMN)) printf("incorrect output of mat_mul_fast\n"); + + // int8 + matmul_op.evaluate(MatmulOperator::INT8_BASELINE, ¶ms_int8); + + params_int8.C.int8_data_ptr = output_C_s8; + + // For int8 SIMD, we need to transpose B first + for (int i = 0; i < B_COLUMN; i++) + for (int j = 0; j < B_ROW; j++) transpose_B_s8[i * B_ROW + j] = MAT_B_s8[j * B_COLUMN + i]; + + params_int8.B.int8_data_ptr = transpose_B_s8; + initialize_matrix(output_C_s8, C_ROW * C_COLUMN); + matmul_op.evaluate(MatmulOperator::INT8_AVX, ¶ms_int8); + if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN)) + printf("incorrect output from mat_mul_avx_int8\n"); + + initialize_matrix(output_C_s8, C_ROW * C_COLUMN); + matmul_op.evaluate(MatmulOperator::INT8_AVX_FAST, ¶ms_int8); + if 
(!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN)) + printf("incorrect output from mat_mul_avx_int8_fast\n"); + + initialize_matrix(output_C_s8, C_ROW * C_COLUMN); + matmul_op.evaluate(MatmulOperator::INT8_AVX_FAST_2x2, ¶ms_int8); + if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN)) + printf("incorrect output from mat_mul_avx_int8_fast_2x2\n"); + + initialize_matrix(output_C_s8, C_ROW * C_COLUMN); + matmul_op.evaluate(MatmulOperator::INT8_AVX_FAST_2x2_32UNROLL, ¶ms_int8); + if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN)) + printf("incorrect output from mat_mul_avx_int8_fast_2x2_32unroll\n"); + + initialize_matrix(output_C_s8, C_ROW * C_COLUMN); + matmul_op.evaluate(MatmulOperator::INT8_AVX_FAST_2x2_OMP, ¶ms_int8); + if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN)) + printf("incorrect output from mat_mul_avx_int8_fast_2x2_omp\n"); + +// ONEDNN +#ifdef ONEDNN_ENABLE + initialize_matrix(output_C_s8, C_ROW * C_COLUMN); + matmul_op.evaluate(MatmulOperator::ONEDNN_INT8, ¶ms_int8); + if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN)) + printf("incorrect output from mat_mul_onednn_int8\n"); +#endif + // Debugging + // dump_integer_array(MAT_A_s8, A_ROW * A_COLUMN); + // dump_integer_array(MAT_B_s8, B_ROW * B_COLUMN); + // dump_integer_array(native_C_s8, C_ROW * C_COLUMN); + // dump_integer_array(output_C_s8, C_ROW * C_COLUMN); + + return 0; +} diff --git a/experimental/matmul_optimization/src/benchmark/main_int8.cc b/experimental/matmul_optimization/src/benchmark/main_int8.cc new file mode 100644 index 00000000..c0c5f3ff --- /dev/null +++ b/experimental/matmul_optimization/src/benchmark/main_int8.cc @@ -0,0 +1,114 @@ +#include +#include + +#include +#include + +#include "lib/matmul.h" + +#define BLK_SIZE 64 +#define MAX_PRECISION_ERROR 0.01 + +#define A_ROW 1024 +#define A_COLUMN 1024 +#define B_ROW 1024 +#define B_COLUMN 1024 +#define C_ROW 1024 +#define C_COLUMN 1024 +#define NUM_THREAD 16 + +int8_t MAT_A_s8[A_ROW * A_COLUMN]; +int8_t MAT_B_s8[B_ROW * B_COLUMN]; +int8_t transpose_B_s8[B_ROW * B_COLUMN]; +int8_t native_C_s8[C_ROW * C_COLUMN]; +int8_t output_C_s8[C_ROW * C_COLUMN]; + +bool check_identical(int8_t matA[], int8_t matB[], int size) { + for (int i = 0; i < size; i++) { + if (matA[i] != matB[i]) { + printf("%d: %d, %d", i, matA[i], matB[i]); + return false; + } + } + return true; +} + +template +void dump_integer_array(T matA[], int size) { + for (int i = 0; i < size; i++) { + printf("%d,", matA[i]); + } + printf("\n"); +} + +void initialize_matrix(int8_t A[], int size) { + for (int i = 0; i < size; i++) { + // A[i] = (rand() % 2) - 1; + A[i] = (rand() % 2); + } +} + +using namespace matmul; + +int main() { + // initialize + initialize_matrix(MAT_A_s8, A_ROW * A_COLUMN); + initialize_matrix(MAT_B_s8, B_ROW * B_COLUMN); + initialize_matrix(native_C_s8, C_ROW * C_COLUMN); + + MatmulOperator matmul_op = MatmulOperator(); + + struct matmul_params params_int8; + // int8 + params_int8.A.row = A_ROW; + params_int8.A.column = A_COLUMN; + params_int8.A.int8_data_ptr = MAT_A_s8; + params_int8.A.qparams.scale = 1.0; + params_int8.A.qparams.zero_point = 0; + params_int8.B.row = B_ROW; + params_int8.B.column = B_COLUMN; + params_int8.B.int8_data_ptr = MAT_B_s8; + params_int8.B.qparams.scale = 1.0; + params_int8.B.qparams.zero_point = 0; + params_int8.C.row = C_ROW; + params_int8.C.column = C_COLUMN; + params_int8.C.int8_data_ptr = native_C_s8; + params_int8.C.qparams.scale = 1.0; + params_int8.C.qparams.zero_point = 0; + 
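+    // q_min/q_max are left at the defaults from matmul.h (-128 and 127), matching
+    // the clamping range that benchmark/main.cc sets explicitly.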
params_int8.opt_params.blk_size = BLK_SIZE; + params_int8.opt_params.num_thread = NUM_THREAD; + + // int8 + matmul_op.evaluate(MatmulOperator::INT8_BASELINE, ¶ms_int8); + +// ONEDNN +#ifdef ONEDNN_ENABLE + initialize_matrix(output_C_s8, C_ROW * C_COLUMN); + params_int8.C.int8_data_ptr = output_C_s8; + matmul_op.evaluate(MatmulOperator::ONEDNN_INT8, ¶ms_int8); + if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN)) + printf("incorrect output from mat_mul_onednn_int8\n"); +#endif + + // For int8 SIMD, we need to transpose B first + for (int i = 0; i < B_COLUMN; i++) + for (int j = 0; j < B_ROW; j++) transpose_B_s8[i * B_ROW + j] = MAT_B_s8[j * B_COLUMN + i]; + + params_int8.B.int8_data_ptr = transpose_B_s8; + initialize_matrix(output_C_s8, C_ROW * C_COLUMN); + matmul_op.evaluate(MatmulOperator::INT8_AVX, ¶ms_int8); + if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN)) + printf("incorrect output from mat_mul_avx_int8\n"); + + initialize_matrix(output_C_s8, C_ROW * C_COLUMN); + matmul_op.evaluate(MatmulOperator::INT8_AVX_FAST, ¶ms_int8); + if (!check_identical(native_C_s8, output_C_s8, C_ROW * C_COLUMN)) + printf("incorrect output from mat_mul_avx_int8_fast\n"); + // Debugging + // dump_integer_array(MAT_A_s8, A_ROW * A_COLUMN); + // dump_integer_array(MAT_B_s8, B_ROW * B_COLUMN); + // dump_integer_array(native_C_s8, C_ROW * C_COLUMN); + // dump_integer_array(output_C_s8, C_ROW * C_COLUMN); + + return 0; +} diff --git a/experimental/matmul_optimization/src/lib/matmul.cu b/experimental/matmul_optimization/src/lib/matmul.cu new file mode 100644 index 00000000..db850e0f --- /dev/null +++ b/experimental/matmul_optimization/src/lib/matmul.cu @@ -0,0 +1,94 @@ +#include +#include +#include "matmul.h" +#include +#include +#include +#include + +const int threadDim = 32; +const int TILE_SIZE = threadDim; +__global__ void matrixMul_blockC(float *A, float *B, float *C, int A_row, int A_column, int B_column){ + int i = blockIdx.x * blockDim.x + threadIdx.x; + int j = blockIdx.y * blockDim.y + threadIdx.y; + + float acc = 0; + for (int k = 0; k < A_column; k++) + acc += A[j * A_column + k] * B[k * B_column + i]; + C[j * B_column +i] = acc; +} + +__global__ void matrixMultiplyShared(const float *A, const float *B, float *C, int A_row, int A_column, int B_column) { + int row = blockIdx.y * blockDim.y + threadIdx.y; + int col = blockIdx.x * blockDim.x + threadIdx.x; + + __shared__ float As[TILE_SIZE][TILE_SIZE]; + __shared__ float Bs[TILE_SIZE][TILE_SIZE]; + + float value = 0; + + for (int i = 0; i < A_column / TILE_SIZE; i++){ + As[threadIdx.y][threadIdx.x] = A[(blockIdx.y * TILE_SIZE + threadIdx.y) * A_column + TILE_SIZE * i + threadIdx.x]; + Bs[threadIdx.y][threadIdx.x] = B[(i * TILE_SIZE + threadIdx.y) * B_column + blockIdx.x * TILE_SIZE + threadIdx.x]; + + __syncthreads(); + + for (int k = 0; k < TILE_SIZE; k++) + value += As[threadIdx.y][k] * Bs[k][threadIdx.x]; + + __syncthreads(); + } + + + C[row * B_column + col] = value; +} + +namespace matmul{ + + void MatmulOperator::mat_mul_cuda(const struct matmul_params *params){ + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + assert(A->column == B->row); + assert(C->column == B->column); + assert(C->row == A->row); + + float *d_A; + float *d_B; + float *d_C; + + // Initailize C + /*for (int i = 0; i < C->row; i++) + for (int j = 0; j < C->column; j++) + C->data_ptr[j + C->column * i] = 0;*/ + + // Allocate memory + cudaMalloc(&d_A, A->column*A->row*sizeof(float)); + cudaMalloc(&d_B, B->column*B->row*sizeof(float)); + 
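+        // Both kernels below overwrite every element of C, so the host-to-device
+        // copy of C->data_ptr further down is not strictly required; return codes
+        // from the CUDA runtime calls are also left unchecked in this benchmark.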
cudaMalloc(&d_C, C->column*C->row*sizeof(float)); + + // Copy data to GPU + cudaMemcpy(d_A, A->data_ptr, A->column*A->row*sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(d_B, B->data_ptr, B->column*B->row*sizeof(float), cudaMemcpyHostToDevice); + cudaMemcpy(d_C, C->data_ptr, C->column*C->row*sizeof(float), cudaMemcpyHostToDevice); + + // Make sure we can break the input matrix into blocks + assert(A->column % threadDim == 0); + assert(A->row % threadDim == 0); + assert(B->column % threadDim == 0); + const dim3 threadsPerBlock(threadDim, threadDim); + const dim3 numBlocks(C->column / threadsPerBlock.x, C->row / threadsPerBlock.y); + + // Invoke the cuda imp. + + struct timeval start, end; + gettimeofday(&start, NULL); + //matrixMul_blockC<<< numBlocks, threadsPerBlock>>>(d_A, d_B, d_C, A->row, A->column, B->column); + matrixMultiplyShared<<< numBlocks, threadsPerBlock>>>(d_A, d_B, d_C, A->row, A->column, B->column); + cudaDeviceSynchronize(); + gettimeofday(&end, NULL); + int us = interval_to_us(&start, &end); + std::cout << "cuda kernel: " << us / 1000 << " ms" << std::endl; + + // Get the result back + cudaMemcpy(C->data_ptr, d_C, C->column*C->row*sizeof(float), cudaMemcpyDeviceToHost); + } +} diff --git a/experimental/matmul_optimization/src/lib/matmul.h b/experimental/matmul_optimization/src/lib/matmul.h new file mode 100644 index 00000000..9f6d3193 --- /dev/null +++ b/experimental/matmul_optimization/src/lib/matmul.h @@ -0,0 +1,88 @@ +#include +#include +// Data structures +struct quantization_params { + float scale; + bool per_channel = false; + int32_t zero_point; + int8_t q_min = -128, q_max = 127; +}; + +struct matrix { + int row; + int column; + float *data_ptr; + int8_t *int8_data_ptr; + int32_t *int32_data_ptr; + struct quantization_params qparams; +}; + +struct optimization_params { + int blk_size; + int num_thread = 8; +}; + +struct matmul_params { + struct matrix A, B, C, bias; + struct optimization_params opt_params; +}; + +struct thread_args { + const struct matrix *A; + const struct matrix *B; + const struct matrix *C; + const struct matmul_params *params; + int start_i, end_i, blk_size; +}; + +#define MAX(A, B) ((A) > (B) ? (A) : (B)) +#define MIN(A, B) ((A) < (B) ? 
(A) : (B)) +namespace matmul { +class MatmulOperator { + public: + enum IMP_TYPE { + NAIVE = 0, + UNROLL = 1, + REORDER = 2, + TILING = 3, + MULTITHREAD = 4, + TRANSPOSE = 5, + TRANSPOSE_SIMD = 6, + FAST = 7, + CUDA = 8, + ONEDNN_FP32 = 9, + INT8_BASELINE = 10, + ONEDNN_INT8 = 11, + INT8_AVX = 12, + INT8_AVX_FAST = 13, + INT8_AVX_FAST_2x2 = 14, + INT8_AVX_FAST_2x2_32UNROLL = 15, + INT8_AVX_FAST_2x2_OMP = 16, + }; + void naive_mat_mul(const struct matmul_params *params); + void mat_mul_unrolling(const struct matmul_params *params); + void mat_mul_reordering(const struct matmul_params *params); + void mat_mul_tiling(const struct matmul_params *params); + void mat_mul_multithreading(const struct matmul_params *params); + void mat_mul_transpose(const struct matmul_params *params); + void mat_mul_transpose_simd(const struct matmul_params *params); + void mat_mul_fast(const struct matmul_params *params); + void mat_mul_onednn(const struct matmul_params *params); + void mat_mul_onednn_int8(const struct matmul_params *params); + void naive_mat_mul_int8(const struct matmul_params *params); + void mat_mul_avx_int8(const struct matmul_params *params); + void mat_mul_avx_int8_fast(const struct matmul_params *params); + void mat_mul_avx_int8_fast_2x2(const struct matmul_params *params); + void mat_mul_avx_int8_fast_2x2_32unroll(const struct matmul_params *params); + void mat_mul_avx_int8_fast_2x2_32unroll_nobias(const struct matmul_params *params); + void mat_mul_avx_int8_fast_2x2_32unroll_nobias_ofp32(const struct matmul_params *params); + void mat_mul_avx_int8_fast_2x2_32unroll_bfp32_ofp32(const struct matmul_params *params); + void mat_mul_avx_int8_fast_2x2_omp(const struct matmul_params *params); + void mat_mul_cuda(const struct matmul_params *params); + void evaluate(IMP_TYPE type, const struct matmul_params *params); + + private: + float interval_to_us(struct timeval *start, struct timeval *end); + void CHECK_MATRICES(const struct matrix *A, const struct matrix *B, const struct matrix *C); +}; +} // namespace matmul diff --git a/experimental/matmul_optimization/src/lib/matmul_avx_int8.cc b/experimental/matmul_optimization/src/lib/matmul_avx_int8.cc new file mode 100644 index 00000000..f73aab7b --- /dev/null +++ b/experimental/matmul_optimization/src/lib/matmul_avx_int8.cc @@ -0,0 +1,898 @@ +#include // AVX instrintics + +#include +#include +#include +#include + +#include "matmul.h" +// #include // currently it is bugged + +namespace matmul { +void dump_64x8_signed(__m256i &target, char *title) { + int8_t *ptr = (int8_t *)⌖ + + printf("%s:", title); + for (int i = 0; i < 64; i++) { + printf("%3d, ", *ptr++); + } + printf("\n"); +} + +void dump_64x8_unsigned(__m256i &target, char *title) { + uint8_t *ptr = (uint8_t *)⌖ + + printf("%s:", title); + for (int i = 0; i < 64; i++) { + printf("%3d, ", *ptr++); + } + printf("\n"); +} + +void dump_32x16_signed(__m256i &target, char *title) { + int16_t *ptr = (int16_t *)⌖ + + printf("%s:", title); + for (int i = 0; i < 32; i++) { + printf("%d, ", *ptr++); + } + printf("\n"); +} + +// element-wise multiply two vectors of 64 8-bit integers and return the accumulate 32-bit result +// We need to assume int8 is in the range of 127 <-> - 127, otherwise, we will expect overflow in some case +// e,g., a[i] = b[i] = -128 -> a[i] * b[i] = 32768 which is not in the range of int16_t(-32768, 32767) +__m256i zero_vec = _mm256_setzero_si256(); +__m256i multiply_signed_int8(__m256i &a, __m256i &b, __m256i &a2, __m256i &b2) { + __m256i a_sign_mask = 
_mm256_cmpgt_epi8(zero_vec, a); // set 0xFF if zero_vec[i] > a[i] + __m256i b_sign_mask = _mm256_cmpgt_epi8(zero_vec, b); // set 0xFF if zero_vec[i] > a[i] + __m256i a2_sign_mask = _mm256_cmpgt_epi8(zero_vec, a2); // set 0xFF if zero_vec[i] > a[i] + __m256i b2_sign_mask = _mm256_cmpgt_epi8(zero_vec, b2); // set 0xFF if zero_vec[i] > a[i] + + // Compute the two's complement of a, put it here for higher throughput with good instruction dep. + __m256i b_abs = _mm256_abs_epi8(b); + __m256i b2_abs = _mm256_abs_epi8(b2); + __m256i a_abs = _mm256_abs_epi8(a); + __m256i a2_abs = _mm256_abs_epi8(a2); + __m256i b_negated = _mm256_sub_epi8(_mm256_set1_epi8(0), b_abs); + __m256i b2_negated = _mm256_sub_epi8(_mm256_set1_epi8(0), b2_abs); + + // Manipulate the `sign` of B to represent the sign of the 16 bit result + __m256i sign_mask_a_sub_b = _mm256_sub_epi8(a_sign_mask, b_sign_mask); + __m256i sign_mask_a2_sub_b2 = _mm256_sub_epi8(a2_sign_mask, b2_sign_mask); + __m256i sign_mask = + _mm256_cmpeq_epi8(sign_mask_a_sub_b, zero_vec); // sign_mask[i] if a[i] and b[i] have different sign bits + __m256i sign_mask2 = _mm256_cmpeq_epi8(sign_mask_a2_sub_b2, zero_vec); + __m256i corrected_b = _mm256_blendv_epi8(b_negated, b_abs, sign_mask); + __m256i corrected_b2 = _mm256_blendv_epi8(b2_negated, b2_abs, sign_mask2); + + // Multiply the absolute values of a_abs (unsigned 8-bit integers) and corrected_b (signed 8-bit integers) + __m256i product_16x16 = _mm256_maddubs_epi16(a_abs, corrected_b); + __m256i product_16x16_2 = _mm256_maddubs_epi16(a2_abs, corrected_b2); + + // Sign extend the 16-bit integers in vector to 32-bit integers + __m256i a_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16, 0)); + __m256i b_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_2, 0)); + __m256i a_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16, 1)); + __m256i b_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_2, 1)); + + // Element-wise add the 32-bit integer vectors + __m256i sum1 = _mm256_add_epi32(a_ext1, b_ext1); + __m256i sum2 = _mm256_add_epi32(a_ext2, b_ext2); + + __m256i sum_product_8x32 = _mm256_add_epi32(sum1, sum2); + + return sum_product_8x32; +} + +// ([a, a2], [c, c2]) * ([b, b2], [d, d2]) +// acc0 = a * b + a2 * b2, acc2 = a * d + a2 * d2, acc3 = c * b + c * b2, acc4 = c * d + c2 * d2 +void multiply_signed_int8_2x2(__m256i &a, __m256i &b, __m256i &a2, __m256i &b2, __m256i &c, __m256i &c2, __m256i &d, + __m256i &d2, __m256i &acc0, __m256i &acc1, __m256i &acc2, __m256i &acc3) { + __m256i a_sign_mask = _mm256_cmpgt_epi8(zero_vec, a); // set 0xFF if zero_vec[i] > a[i] + __m256i b_sign_mask = _mm256_cmpgt_epi8(zero_vec, b); // set 0xFF if zero_vec[i] > a[i] + __m256i a2_sign_mask = _mm256_cmpgt_epi8(zero_vec, a2); + __m256i b2_sign_mask = _mm256_cmpgt_epi8(zero_vec, b2); + __m256i c_sign_mask = _mm256_cmpgt_epi8(zero_vec, c); + __m256i d_sign_mask = _mm256_cmpgt_epi8(zero_vec, d); + __m256i c2_sign_mask = _mm256_cmpgt_epi8(zero_vec, c2); + __m256i d2_sign_mask = _mm256_cmpgt_epi8(zero_vec, d2); + + // Compute the two's complement of a, put it here for higher throughput with good instruction dep. 
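+    // The trick used throughout this file: _mm256_maddubs_epi16 multiplies unsigned
+    // bytes by signed bytes, so each operand pair is rewritten as |a| times a copy of
+    // b that carries the sign of the product (|b| when a and b agree in sign, -|b|
+    // otherwise). The pairwise 16-bit sums it produces cannot overflow as long as the
+    // inputs stay within [-127, 127], as noted above multiply_signed_int8.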
+ __m256i b_abs = _mm256_abs_epi8(b); + __m256i b2_abs = _mm256_abs_epi8(b2); + __m256i a_abs = _mm256_abs_epi8(a); + __m256i a2_abs = _mm256_abs_epi8(a2); + __m256i d_abs = _mm256_abs_epi8(d); + __m256i d2_abs = _mm256_abs_epi8(d2); + __m256i c_abs = _mm256_abs_epi8(c); + __m256i c2_abs = _mm256_abs_epi8(c2); + __m256i b_negated = _mm256_sub_epi8(_mm256_set1_epi8(0), b_abs); + __m256i b2_negated = _mm256_sub_epi8(_mm256_set1_epi8(0), b2_abs); + __m256i d_negated = _mm256_sub_epi8(_mm256_set1_epi8(0), d_abs); + __m256i d2_negated = _mm256_sub_epi8(_mm256_set1_epi8(0), d2_abs); + + // Manipulate the `sign` of B to represent the sign of the 16 bit result + __m256i sign_mask_a_sub_b = _mm256_sub_epi8(a_sign_mask, b_sign_mask); + __m256i sign_mask_a_sub_d = _mm256_sub_epi8(a_sign_mask, d_sign_mask); + __m256i sign_mask_a2_sub_b2 = _mm256_sub_epi8(a2_sign_mask, b2_sign_mask); + __m256i sign_mask_a2_sub_d2 = _mm256_sub_epi8(a2_sign_mask, d2_sign_mask); + __m256i sign_mask_c_sub_b = _mm256_sub_epi8(c_sign_mask, b_sign_mask); + __m256i sign_mask_c_sub_d = _mm256_sub_epi8(c_sign_mask, d_sign_mask); + __m256i sign_mask_c2_sub_b2 = _mm256_sub_epi8(c2_sign_mask, b2_sign_mask); + __m256i sign_mask_c2_sub_d2 = _mm256_sub_epi8(c2_sign_mask, d2_sign_mask); + + // sign_mask[i] if a[i] and b[i] have different sign bits + __m256i sign_mask_ab = _mm256_cmpeq_epi8(sign_mask_a_sub_b, zero_vec); + __m256i sign_mask2_a2_b2 = _mm256_cmpeq_epi8(sign_mask_a2_sub_b2, zero_vec); + __m256i sign_mask_ad = _mm256_cmpeq_epi8(sign_mask_a_sub_d, zero_vec); + __m256i sign_mask2_a2_d2 = _mm256_cmpeq_epi8(sign_mask_a2_sub_d2, zero_vec); + __m256i sign_mask_cb = _mm256_cmpeq_epi8(sign_mask_c_sub_b, zero_vec); + __m256i sign_mask2_c2_b2 = _mm256_cmpeq_epi8(sign_mask_c2_sub_b2, zero_vec); + __m256i sign_mask_cd = _mm256_cmpeq_epi8(sign_mask_c_sub_d, zero_vec); + __m256i sign_mask2_c2_d2 = _mm256_cmpeq_epi8(sign_mask_c2_sub_d2, zero_vec); + + __m256i corrected_ab = _mm256_blendv_epi8(b_negated, b_abs, sign_mask_ab); + __m256i corrected_a2b2 = _mm256_blendv_epi8(b2_negated, b2_abs, sign_mask2_a2_b2); + __m256i corrected_ad = _mm256_blendv_epi8(d_negated, d_abs, sign_mask_ad); + __m256i corrected_a2d2 = _mm256_blendv_epi8(d2_negated, d2_abs, sign_mask2_a2_d2); + __m256i corrected_cb = _mm256_blendv_epi8(b_negated, b_abs, sign_mask_cb); + __m256i corrected_c2b2 = _mm256_blendv_epi8(b2_negated, b2_abs, sign_mask2_c2_b2); + __m256i corrected_cd = _mm256_blendv_epi8(d_negated, d_abs, sign_mask_cd); + __m256i corrected_c2d2 = _mm256_blendv_epi8(d2_negated, d2_abs, sign_mask2_c2_d2); + + // Multiply the absolute values of a_abs (unsigned 8-bit integers) and corrected_b (signed 8-bit integers) + __m256i product_16x16_ab = _mm256_maddubs_epi16(a_abs, corrected_ab); + __m256i product_16x16_ab2 = _mm256_maddubs_epi16(a2_abs, corrected_a2b2); + __m256i product_16x16_ad = _mm256_maddubs_epi16(a_abs, corrected_ad); + __m256i product_16x16_ad2 = _mm256_maddubs_epi16(a2_abs, corrected_a2d2); + __m256i product_16x16_cb = _mm256_maddubs_epi16(c_abs, corrected_cb); + __m256i product_16x16_cb2 = _mm256_maddubs_epi16(c2_abs, corrected_c2b2); + __m256i product_16x16_cd = _mm256_maddubs_epi16(c_abs, corrected_cd); + __m256i product_16x16_cd2 = _mm256_maddubs_epi16(c2_abs, corrected_c2d2); + + // Sign extend the 16-bit integers in vector to 32-bit integers + __m256i ab_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ab, 0)); + __m256i ab2_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ab2, 0)); + __m256i ab_ext2 = 
_mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ab, 1)); + __m256i ab2_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ab2, 1)); + __m256i ad_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ad, 0)); + __m256i ad2_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ad2, 0)); + __m256i ad_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ad, 1)); + __m256i ad2_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ad2, 1)); + __m256i cb_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cb, 0)); + __m256i cb2_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cb2, 0)); + __m256i cb_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cb, 1)); + __m256i cb2_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cb2, 1)); + __m256i cd_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cd, 0)); + __m256i cd2_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cd2, 0)); + __m256i cd_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cd, 1)); + __m256i cd2_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cd2, 1)); + + // Element-wise add the 32-bit integer vectors + // acc0 += a * b + a2 * b2, acc2 += a * d + a2 * d2, acc3 += c * b + c * b2, acc4 += c * d + c2 * d2 + acc0 = _mm256_add_epi32(acc0, + _mm256_add_epi32(_mm256_add_epi32(ab_ext1, ab2_ext1), _mm256_add_epi32(ab_ext2, ab2_ext2))); + acc1 = _mm256_add_epi32(acc1, + _mm256_add_epi32(_mm256_add_epi32(ad_ext1, ad2_ext1), _mm256_add_epi32(ad_ext2, ad2_ext2))); + acc2 = _mm256_add_epi32(acc2, + _mm256_add_epi32(_mm256_add_epi32(cb_ext1, cb2_ext1), _mm256_add_epi32(cb_ext2, cb2_ext2))); + acc3 = _mm256_add_epi32(acc3, + _mm256_add_epi32(_mm256_add_epi32(cd_ext1, cd2_ext1), _mm256_add_epi32(cd_ext2, cd2_ext2))); +} + +static inline void multiply_signed_int8_2x2_32epi(__m256i &a, __m256i &b, __m256i &c, __m256i &d, __m256i &acc0, + __m256i &acc1, __m256i &acc2, __m256i &acc3) { + __m256i a_sign_mask = _mm256_cmpgt_epi8(zero_vec, a); // set 0xFF if zero_vec[i] > a[i] + __m256i b_sign_mask = _mm256_cmpgt_epi8(zero_vec, b); // set 0xFF if zero_vec[i] > a[i] + __m256i c_sign_mask = _mm256_cmpgt_epi8(zero_vec, c); + __m256i d_sign_mask = _mm256_cmpgt_epi8(zero_vec, d); + + // Compute the two's complement of a, put it here for higher throughput with good instruction dep. 
+ __m256i b_abs = _mm256_abs_epi8(b); + __m256i a_abs = _mm256_abs_epi8(a); + __m256i d_abs = _mm256_abs_epi8(d); + __m256i c_abs = _mm256_abs_epi8(c); + __m256i b_negated = _mm256_sub_epi8(_mm256_set1_epi8(0), b_abs); + __m256i d_negated = _mm256_sub_epi8(_mm256_set1_epi8(0), d_abs); + + // Manipulate the `sign` of B to represent the sign of the 16 bit result + __m256i sign_mask_a_sub_b = _mm256_sub_epi8(a_sign_mask, b_sign_mask); + __m256i sign_mask_a_sub_d = _mm256_sub_epi8(a_sign_mask, d_sign_mask); + __m256i sign_mask_c_sub_b = _mm256_sub_epi8(c_sign_mask, b_sign_mask); + __m256i sign_mask_c_sub_d = _mm256_sub_epi8(c_sign_mask, d_sign_mask); + + // sign_mask[i] if a[i] and b[i] have different sign bits + __m256i sign_mask_ab = _mm256_cmpeq_epi8(sign_mask_a_sub_b, zero_vec); + __m256i sign_mask_ad = _mm256_cmpeq_epi8(sign_mask_a_sub_d, zero_vec); + __m256i sign_mask_cb = _mm256_cmpeq_epi8(sign_mask_c_sub_b, zero_vec); + __m256i sign_mask_cd = _mm256_cmpeq_epi8(sign_mask_c_sub_d, zero_vec); + + __m256i corrected_ab = _mm256_blendv_epi8(b_negated, b_abs, sign_mask_ab); + __m256i corrected_ad = _mm256_blendv_epi8(d_negated, d_abs, sign_mask_ad); + __m256i corrected_cb = _mm256_blendv_epi8(b_negated, b_abs, sign_mask_cb); + __m256i corrected_cd = _mm256_blendv_epi8(d_negated, d_abs, sign_mask_cd); + + // Multiply the absolute values of a_abs (unsigned 8-bit integers) and corrected_b (signed 8-bit integers) + __m256i product_16x16_ab = _mm256_maddubs_epi16(a_abs, corrected_ab); + __m256i product_16x16_ad = _mm256_maddubs_epi16(a_abs, corrected_ad); + __m256i product_16x16_cb = _mm256_maddubs_epi16(c_abs, corrected_cb); + __m256i product_16x16_cd = _mm256_maddubs_epi16(c_abs, corrected_cd); + + // Sign extend the 16-bit integers in vector to 32-bit integers + __m256i ab_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ab, 0)); + __m256i ab_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ab, 1)); + __m256i ad_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ad, 0)); + __m256i ad_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ad, 1)); + __m256i cb_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cb, 0)); + __m256i cb_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cb, 1)); + __m256i cd_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cd, 0)); + __m256i cd_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cd, 1)); + + // Element-wise add the 32-bit integer vectors + // acc0 += a * b + a2 * b2, acc2 += a * d + a2 * d2, acc3 += c * b + c * b2, acc4 += c * d + c2 * d2 + acc0 = _mm256_add_epi32(acc0, _mm256_add_epi32(ab_ext1, ab_ext2)); + acc1 = _mm256_add_epi32(acc1, _mm256_add_epi32(ad_ext1, ad_ext2)); + acc2 = _mm256_add_epi32(acc2, _mm256_add_epi32(cb_ext1, cb_ext2)); + acc3 = _mm256_add_epi32(acc3, _mm256_add_epi32(cd_ext1, cd_ext2)); +} + +static inline void multiply_signed_int8_2x2_32epi_of(__m256i &a, __m256i &b, __m256i &c, __m256i &d, __m256i &acc0, + __m256i &acc1, __m256i &acc2, __m256i &acc3) { + // Multiply the absolute values of a_abs (unsigned 8-bit integers) and corrected_b (signed 8-bit integers) + __m256i product_16x16_ab = _mm256_maddubs_epi16(a, b); + __m256i product_16x16_ad = _mm256_maddubs_epi16(a, d); + __m256i product_16x16_cb = _mm256_maddubs_epi16(c, b); + __m256i product_16x16_cd = _mm256_maddubs_epi16(c, d); + + // Sign extend the 16-bit integers in vector to 32-bit integers + __m256i ab_ext1 = 
_mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ab, 0)); + __m256i ab_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ab, 1)); + __m256i ad_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ad, 0)); + __m256i ad_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_ad, 1)); + __m256i cb_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cb, 0)); + __m256i cb_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cb, 1)); + __m256i cd_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cd, 0)); + __m256i cd_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_cd, 1)); + + // Element-wise add the 32-bit integer vectors + // acc0 += a * b + a2 * b2, acc2 += a * d + a2 * d2, acc3 += c * b + c * b2, acc4 += c * d + c2 * d2 + acc0 = _mm256_add_epi32(acc0, _mm256_add_epi32(ab_ext1, ab_ext2)); + acc1 = _mm256_add_epi32(acc1, _mm256_add_epi32(ad_ext1, ad_ext2)); + acc2 = _mm256_add_epi32(acc2, _mm256_add_epi32(cb_ext1, cb_ext2)); + acc3 = _mm256_add_epi32(acc3, _mm256_add_epi32(cd_ext1, cd_ext2)); +} +// Note: This implementation could have potential overflow! +// __m256i multiply_signed_int8(__m256i &a, __m256i &b, __m256i &a2, __m256i &b2) { +// // Multiply the absolute values of a_abs (unsigned 8-bit integers) and corrected_b (signed 8-bit integers) +// __m256i product_16x16 = _mm256_maddubs_epi16(a, b); +// __m256i product_16x16_2 = _mm256_maddubs_epi16(a2, b2); + +// // Sign extend the 16-bit integers in vector to 32-bit integers +// __m256i a_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16, 0)); +// __m256i b_ext1 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_2, 0)); +// __m256i a_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16, 1)); +// __m256i b_ext2 = _mm256_cvtepi16_epi32(_mm256_extracti128_si256(product_16x16_2, 1)); + +// // Element-wise add the 32-bit integer vectors +// __m256i sum1 = _mm256_add_epi32(a_ext1, b_ext1); +// __m256i sum2 = _mm256_add_epi32(a_ext2, b_ext2); + +// __m256i sum_product_8x32 = _mm256_add_epi32(sum1, sum2); + +// return sum_product_8x32; +// } + +void MatmulOperator::mat_mul_avx_int8(const struct matmul_params *params) { + int i, j, k; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + int32_t A_zp = A->qparams.zero_point, C_zp = C->qparams.zero_point; + float A_sc = A->qparams.scale, B_sc = B->qparams.scale, C_sc = C->qparams.scale; + float effective_scale = A_sc * B_sc / C_sc; + int8_t *data_A = A->int8_data_ptr, *data_B = B->int8_data_ptr, *data_C = C->int8_data_ptr; + const int8_t q_min = C->qparams.q_min, q_max = C->qparams.q_max; + + assert(A->column % 64 == 0); + + for (i = 0; i < C->row; i++) + for (j = 0; j < C->column; j++) { + int acc = 0; + __m256i acc0_8x32 = _mm256_setzero_si256(); + for (k = 0; k < A->column; k += 64) { + __m256i aa = _mm256_loadu_si256((const __m256i_u *)&data_A[i * A->column + k]), + aa2 = _mm256_loadu_si256((const __m256i_u *)(&data_A[i * A->column + k + 32])); + // assume B is transposed + __m256i bb = _mm256_loadu_si256((const __m256i_u *)&data_B[j * B->row + k]), + bb2 = _mm256_loadu_si256((const __m256i_u *)(&data_B[j * B->row + k + 32])); + + acc0_8x32 = _mm256_add_epi32(acc0_8x32, multiply_signed_int8(aa, bb, aa2, bb2)); + } + int32_t *accptr = (int32_t *)&acc0_8x32; + acc = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + acc = (int32_t)((float)acc * effective_scale); + acc -= 
C_zp; + acc = MAX(acc, q_min); + acc = MIN(acc, q_max); + data_C[i * C->column + j] = (int8_t)acc; + } +} + +void *mat_mul_avx_int8_thread_func(void *args) { + int i, j, k; + struct thread_args *thread_args = (struct thread_args *)args; + const struct matmul_params *params = thread_args->params; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + int32_t A_zp = A->qparams.zero_point, C_zp = C->qparams.zero_point; + float A_sc = A->qparams.scale, B_sc = B->qparams.scale, C_sc = C->qparams.scale; + float effective_scale = A_sc * B_sc / C_sc; + int8_t *data_A = A->int8_data_ptr, *data_B = B->int8_data_ptr, *data_C = C->int8_data_ptr; + int start_i = thread_args->start_i, end_i = thread_args->end_i; + const int8_t q_min = C->qparams.q_min, q_max = C->qparams.q_max; + + for (i = start_i; i < end_i; i++) + for (j = 0; j < C->column; j++) { + int acc = 0; + __m256i acc0_8x32 = _mm256_setzero_si256(); + for (k = 0; k < A->column; k += 64) { + __m256i aa = _mm256_loadu_si256((const __m256i_u *)&data_A[i * A->column + k]), + aa2 = _mm256_loadu_si256((const __m256i_u *)(&data_A[i * A->column + k + 32])); + // assume B is transposed + __m256i bb = _mm256_loadu_si256((const __m256i_u *)&data_B[j * B->row + k]), + bb2 = _mm256_loadu_si256((const __m256i_u *)(&data_B[j * B->row + k + 32])); + + acc0_8x32 = _mm256_add_epi32(acc0_8x32, multiply_signed_int8(aa, bb, aa2, bb2)); + } + int32_t *accptr = (int32_t *)&acc0_8x32; + acc = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + acc = (int32_t)std::round((float)acc * effective_scale); + acc -= C_zp; + acc = MAX(acc, q_min); + acc = MIN(acc, q_max); + data_C[i * C->column + j] = (int8_t)acc; + } + return NULL; +} + +void MatmulOperator::mat_mul_avx_int8_fast(const struct matmul_params *params) { + int j, num_thread = params->opt_params.num_thread; + + assert(params->A.column % 64 == 0); + + pthread_t thread_pool[num_thread]; + struct thread_args threads_args[num_thread]; + + // Thread creation + for (j = 0; j < num_thread; j++) { + threads_args[j].start_i = j * (params->C.row / num_thread); + threads_args[j].end_i = (j + 1) * (params->C.row / num_thread); + threads_args[j].blk_size = params->opt_params.blk_size; + threads_args[j].params = params; + pthread_create(&thread_pool[j], NULL, mat_mul_avx_int8_thread_func, &threads_args[j]); + } + // Join threads + for (j = 0; j < num_thread; j++) { + pthread_join(thread_pool[j], NULL); + } +} + +void *mat_mul_avx_int8_thread_func_2x2(void *args) { + int i, j, k; + struct thread_args *thread_args = (struct thread_args *)args; + const struct matmul_params *params = thread_args->params; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + int32_t A_zp = A->qparams.zero_point, C_zp = C->qparams.zero_point; + float A_sc = A->qparams.scale, B_sc = B->qparams.scale, C_sc = C->qparams.scale; + float effective_scale = A_sc * B_sc / C_sc; + int8_t *data_A = A->int8_data_ptr, *data_B = B->int8_data_ptr, *data_C = C->int8_data_ptr; + int start_i = thread_args->start_i, end_i = thread_args->end_i; + const int8_t q_min = C->qparams.q_min, q_max = C->qparams.q_max; + + assert((end_i - start_i) % 2 == 0); + + for (i = start_i; i < end_i; i += 2) + + for (j = 0; j < C->column; j += 2) { + // (i, j), (i, j+1), (i+1, j), (i+1, j+1) + int acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0; + __m256i acc0_8x32 = _mm256_setzero_si256(), acc1_8x32 = _mm256_setzero_si256(), + acc2_8x32 = _mm256_setzero_si256(), acc3_8x32 = _mm256_setzero_si256(); + for (k = 0; k < A->column; k += 
64) { + __m256i aa = _mm256_loadu_si256((const __m256i_u *)&data_A[i * A->column + k]), + aa2 = _mm256_loadu_si256((const __m256i_u *)(&data_A[i * A->column + k + 32])); + __m256i cc = _mm256_loadu_si256((const __m256i_u *)&data_A[(i + 1) * A->column + k]), + cc2 = _mm256_loadu_si256((const __m256i_u *)(&data_A[(i + 1) * A->column + k + 32])); + // assume B is transposed + __m256i bb = _mm256_loadu_si256((const __m256i_u *)&data_B[j * B->row + k]), + bb2 = _mm256_loadu_si256((const __m256i_u *)(&data_B[j * B->row + k + 32])); + __m256i dd = _mm256_loadu_si256((const __m256i_u *)&data_B[(j + 1) * B->row + k]), + dd2 = _mm256_loadu_si256((const __m256i_u *)(&data_B[(j + 1) * B->row + k + 32])); + + multiply_signed_int8_2x2(aa, bb, aa2, bb2, cc, cc2, dd, dd2, acc0_8x32, acc1_8x32, acc2_8x32, + acc3_8x32); + } + int32_t *accptr = (int32_t *)&acc0_8x32; + acc0 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7] + + params->bias.int32_data_ptr[j]; + accptr = (int32_t *)&acc1_8x32; + acc1 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7] + + params->bias.int32_data_ptr[j + 1]; + accptr = (int32_t *)&acc2_8x32; + acc2 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7] + + params->bias.int32_data_ptr[j]; + accptr = (int32_t *)&acc3_8x32; + acc3 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7] + + params->bias.int32_data_ptr[j + 1]; + + acc0 = (int32_t)std::round((float)acc0 * effective_scale); + acc1 = (int32_t)std::round((float)acc1 * effective_scale); + acc2 = (int32_t)std::round((float)acc2 * effective_scale); + acc3 = (int32_t)std::round((float)acc3 * effective_scale); + + acc0 -= C_zp; + acc1 -= C_zp; + acc2 -= C_zp; + acc3 -= C_zp; + + acc0 = MAX(acc0, q_min); + acc1 = MAX(acc1, q_min); + acc2 = MAX(acc2, q_min); + acc3 = MAX(acc3, q_min); + acc0 = MIN(acc0, q_max); + acc1 = MIN(acc1, q_max); + acc2 = MIN(acc2, q_max); + acc3 = MIN(acc3, q_max); + data_C[i * C->column + j] = (int8_t)acc0; + data_C[i * C->column + j + 1] = (int8_t)acc1; + data_C[(i + 1) * C->column + j] = (int8_t)acc2; + data_C[(i + 1) * C->column + j + 1] = (int8_t)acc3; + } + return NULL; +} + +void *mat_mul_avx_int8_thread_func_2x2_32unroll(void *args) { + int i, j, k; + struct thread_args *thread_args = (struct thread_args *)args; + const struct matmul_params *params = thread_args->params; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + int32_t A_zp = A->qparams.zero_point, C_zp = C->qparams.zero_point; + float A_sc = A->qparams.scale, B_sc = B->qparams.scale, C_sc = C->qparams.scale; + float effective_scale = A_sc * B_sc / C_sc; + int8_t *data_A = A->int8_data_ptr, *data_B = B->int8_data_ptr, *data_C = C->int8_data_ptr; + int start_i = thread_args->start_i, end_i = thread_args->end_i; + const int8_t q_min = C->qparams.q_min, q_max = C->qparams.q_max; + + assert((end_i - start_i) % 2 == 0); + + for (i = start_i; i < end_i; i += 2) + + for (j = 0; j < C->column; j += 2) { + // (i, j), (i, j+1), (i+1, j), (i+1, j+1) + int acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0; + __m256i acc0_8x32 = _mm256_setzero_si256(), acc1_8x32 = _mm256_setzero_si256(), + acc2_8x32 = _mm256_setzero_si256(), acc3_8x32 = _mm256_setzero_si256(); + for (k = 0; k < A->column; k += 32) { + __m256i aa = _mm256_loadu_si256((const __m256i_u *)&data_A[i * A->column + k]); + __m256i cc = _mm256_loadu_si256((const __m256i_u *)&data_A[(i + 1) * A->column + k]); + __m256i 
bb = _mm256_loadu_si256((const __m256i_u *)&data_B[j * B->row + k]); + __m256i dd = _mm256_loadu_si256((const __m256i_u *)&data_B[(j + 1) * B->row + k]); + + // multiply_signed_int8_2x2_32epi_of(aa, bb, cc, dd, acc0_8x32, acc1_8x32, acc2_8x32, acc3_8x32); + multiply_signed_int8_2x2_32epi(aa, bb, cc, dd, acc0_8x32, acc1_8x32, acc2_8x32, acc3_8x32); + } + int32_t *accptr = (int32_t *)&acc0_8x32; + acc0 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + acc0 += params->bias.int32_data_ptr[j]; + accptr = (int32_t *)&acc1_8x32; + acc1 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + acc1 += params->bias.int32_data_ptr[j + 1]; + accptr = (int32_t *)&acc2_8x32; + acc2 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + acc2 += params->bias.int32_data_ptr[j]; + accptr = (int32_t *)&acc3_8x32; + acc3 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + acc3 += params->bias.int32_data_ptr[j + 1]; + + acc0 = (int32_t)std::round((float)acc0 * effective_scale); + acc1 = (int32_t)std::round((float)acc1 * effective_scale); + acc2 = (int32_t)std::round((float)acc2 * effective_scale); + acc3 = (int32_t)std::round((float)acc3 * effective_scale); + + acc0 -= C_zp; + acc1 -= C_zp; + acc2 -= C_zp; + acc3 -= C_zp; + + acc0 = MAX(acc0, q_min); + acc1 = MAX(acc1, q_min); + acc2 = MAX(acc2, q_min); + acc3 = MAX(acc3, q_min); + acc0 = MIN(acc0, q_max); + acc1 = MIN(acc1, q_max); + acc2 = MIN(acc2, q_max); + acc3 = MIN(acc3, q_max); + data_C[i * C->column + j] = (int8_t)acc0; + data_C[i * C->column + j + 1] = (int8_t)acc1; + data_C[(i + 1) * C->column + j] = (int8_t)acc2; + data_C[(i + 1) * C->column + j + 1] = (int8_t)acc3; + } + return NULL; +} + +void MatmulOperator::mat_mul_avx_int8_fast_2x2_32unroll(const struct matmul_params *params) { + int j, num_thread = params->opt_params.num_thread; + + assert(params->A.column % 64 == 0); + assert((params->C.column) % 2 == 0); + + pthread_t thread_pool[num_thread]; + struct thread_args threads_args[num_thread]; + + // Thread creation + for (j = 0; j < num_thread; j++) { + threads_args[j].start_i = j * (params->C.row / num_thread); + threads_args[j].end_i = (j + 1) * (params->C.row / num_thread); + threads_args[j].blk_size = params->opt_params.blk_size; + threads_args[j].params = params; + pthread_create(&thread_pool[j], NULL, mat_mul_avx_int8_thread_func_2x2_32unroll, &threads_args[j]); + } + // Join threads + for (j = 0; j < num_thread; j++) { + pthread_join(thread_pool[j], NULL); + } +} + +void *mat_mul_avx_int8_thread_func_2x2_32unroll_nobias(void *args) { + int i, j, k; + struct thread_args *thread_args = (struct thread_args *)args; + const struct matmul_params *params = thread_args->params; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + int32_t A_zp = A->qparams.zero_point, C_zp = C->qparams.zero_point; + float A_sc = A->qparams.scale, B_sc = B->qparams.scale, C_sc = C->qparams.scale; + float effective_scale = A_sc * B_sc / C_sc; + int8_t *data_A = A->int8_data_ptr, *data_B = B->int8_data_ptr, *data_C = C->int8_data_ptr; + int start_i = thread_args->start_i, end_i = thread_args->end_i; + const int8_t q_min = C->qparams.q_min, q_max = C->qparams.q_max; + + assert((end_i - start_i) % 2 == 0); + + for (i = start_i; i < end_i; i += 2) + + for (j = 0; j < C->column; j += 2) { + // (i, j), (i, j+1), (i+1, j), (i+1, j+1) + int acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0; + 
__m256i acc0_8x32 = _mm256_setzero_si256(), acc1_8x32 = _mm256_setzero_si256(), + acc2_8x32 = _mm256_setzero_si256(), acc3_8x32 = _mm256_setzero_si256(); + for (k = 0; k < A->column; k += 32) { + __m256i aa = _mm256_loadu_si256((const __m256i_u *)&data_A[i * A->column + k]); + __m256i cc = _mm256_loadu_si256((const __m256i_u *)&data_A[(i + 1) * A->column + k]); + __m256i bb = _mm256_loadu_si256((const __m256i_u *)&data_B[j * B->row + k]); + __m256i dd = _mm256_loadu_si256((const __m256i_u *)&data_B[(j + 1) * B->row + k]); + + // multiply_signed_int8_2x2_32epi_of(aa, bb, cc, dd, acc0_8x32, acc1_8x32, acc2_8x32, acc3_8x32); + multiply_signed_int8_2x2_32epi(aa, bb, cc, dd, acc0_8x32, acc1_8x32, acc2_8x32, acc3_8x32); + } + int32_t *accptr = (int32_t *)&acc0_8x32; + acc0 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + accptr = (int32_t *)&acc1_8x32; + acc1 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + accptr = (int32_t *)&acc2_8x32; + acc2 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + accptr = (int32_t *)&acc3_8x32; + acc3 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + + acc0 = (int32_t)std::round((float)acc0 * effective_scale); + acc1 = (int32_t)std::round((float)acc1 * effective_scale); + acc2 = (int32_t)std::round((float)acc2 * effective_scale); + acc3 = (int32_t)std::round((float)acc3 * effective_scale); + + acc0 -= C_zp; + acc1 -= C_zp; + acc2 -= C_zp; + acc3 -= C_zp; + + acc0 = MAX(acc0, q_min); + acc1 = MAX(acc1, q_min); + acc2 = MAX(acc2, q_min); + acc3 = MAX(acc3, q_min); + acc0 = MIN(acc0, q_max); + acc1 = MIN(acc1, q_max); + acc2 = MIN(acc2, q_max); + acc3 = MIN(acc3, q_max); + data_C[i * C->column + j] = (int8_t)acc0; + data_C[i * C->column + j + 1] = (int8_t)acc1; + data_C[(i + 1) * C->column + j] = (int8_t)acc2; + data_C[(i + 1) * C->column + j + 1] = (int8_t)acc3; + } + return NULL; +} + +void MatmulOperator::mat_mul_avx_int8_fast_2x2_32unroll_nobias(const struct matmul_params *params) { + int j, num_thread = params->opt_params.num_thread; + + assert(params->A.column % 64 == 0); + assert((params->C.column) % 2 == 0); + + pthread_t thread_pool[num_thread]; + struct thread_args threads_args[num_thread]; + + // Thread creation + for (j = 0; j < num_thread; j++) { + threads_args[j].start_i = j * (params->C.row / num_thread); + threads_args[j].end_i = (j + 1) * (params->C.row / num_thread); + threads_args[j].blk_size = params->opt_params.blk_size; + threads_args[j].params = params; + pthread_create(&thread_pool[j], NULL, mat_mul_avx_int8_thread_func_2x2_32unroll_nobias, &threads_args[j]); + } + // Join threads + for (j = 0; j < num_thread; j++) { + pthread_join(thread_pool[j], NULL); + } +} + +void *mat_mul_avx_int8_thread_func_2x2_32unroll_nobias_ofp32(void *args) { + int i, j, k; + struct thread_args *thread_args = (struct thread_args *)args; + const struct matmul_params *params = thread_args->params; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + int32_t A_zp = A->qparams.zero_point, C_zp = C->qparams.zero_point; + float A_sc = A->qparams.scale, B_sc = B->qparams.scale, C_sc = C->qparams.scale; + float effective_scale = A_sc * B_sc / C_sc; + int8_t *data_A = A->int8_data_ptr, *data_B = B->int8_data_ptr; + float *data_C = C->data_ptr; + int start_i = thread_args->start_i, end_i = thread_args->end_i; + const int8_t q_min = C->qparams.q_min, q_max = 
C->qparams.q_max; + + assert((end_i - start_i) % 2 == 0); + + for (i = start_i; i < end_i; i += 2) + + for (j = 0; j < C->column; j += 2) { + // (i, j), (i, j+1), (i+1, j), (i+1, j+1) + int acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0; + __m256i acc0_8x32 = _mm256_setzero_si256(), acc1_8x32 = _mm256_setzero_si256(), + acc2_8x32 = _mm256_setzero_si256(), acc3_8x32 = _mm256_setzero_si256(); + for (k = 0; k < A->column; k += 32) { + __m256i aa = _mm256_loadu_si256((const __m256i_u *)&data_A[i * A->column + k]); + __m256i cc = _mm256_loadu_si256((const __m256i_u *)&data_A[(i + 1) * A->column + k]); + __m256i bb = _mm256_loadu_si256((const __m256i_u *)&data_B[j * B->row + k]); + __m256i dd = _mm256_loadu_si256((const __m256i_u *)&data_B[(j + 1) * B->row + k]); + + multiply_signed_int8_2x2_32epi(aa, bb, cc, dd, acc0_8x32, acc1_8x32, acc2_8x32, acc3_8x32); + } + int32_t *accptr = (int32_t *)&acc0_8x32; + acc0 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + accptr = (int32_t *)&acc1_8x32; + acc1 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + accptr = (int32_t *)&acc2_8x32; + acc2 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + accptr = (int32_t *)&acc3_8x32; + acc3 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + + data_C[i * C->column + j] = ((float)acc0 * effective_scale); + data_C[i * C->column + j + 1] = ((float)acc1 * effective_scale); + data_C[(i + 1) * C->column + j] = ((float)acc2 * effective_scale); + data_C[(i + 1) * C->column + j + 1] = ((float)acc3 * effective_scale); + } + return NULL; +} + +void MatmulOperator::mat_mul_avx_int8_fast_2x2_32unroll_nobias_ofp32(const struct matmul_params *params) { + int j, num_thread = params->opt_params.num_thread; + + assert(params->A.column % 64 == 0); + assert((params->C.column) % 2 == 0); + + pthread_t thread_pool[num_thread]; + struct thread_args threads_args[num_thread]; + + // Thread creation + for (j = 0; j < num_thread; j++) { + threads_args[j].start_i = j * (params->C.row / num_thread); + threads_args[j].end_i = (j + 1) * (params->C.row / num_thread); + threads_args[j].blk_size = params->opt_params.blk_size; + threads_args[j].params = params; + pthread_create(&thread_pool[j], NULL, mat_mul_avx_int8_thread_func_2x2_32unroll_nobias_ofp32, &threads_args[j]); + } + // Join threads + for (j = 0; j < num_thread; j++) { + pthread_join(thread_pool[j], NULL); + } +} + +void *mat_mul_avx_int8_thread_func_2x2_32unroll_bfp32_ofp32(void *args) { + int i, j, k; + struct thread_args *thread_args = (struct thread_args *)args; + const struct matmul_params *params = thread_args->params; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + int32_t A_zp = A->qparams.zero_point, C_zp = C->qparams.zero_point; + float A_sc = A->qparams.scale, B_sc = B->qparams.scale, C_sc = C->qparams.scale; + float effective_scale = A_sc * B_sc / C_sc; + int8_t *data_A = A->int8_data_ptr, *data_B = B->int8_data_ptr; + float *data_C = C->data_ptr; + int start_i = thread_args->start_i, end_i = thread_args->end_i; + const int8_t q_min = C->qparams.q_min, q_max = C->qparams.q_max; + + assert((end_i - start_i) % 2 == 0); + + for (i = start_i; i < end_i; i += 2) + + for (j = 0; j < C->column; j += 2) { + // (i, j), (i, j+1), (i+1, j), (i+1, j+1) + int acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0; + __m256i acc0_8x32 = _mm256_setzero_si256(), acc1_8x32 = _mm256_setzero_si256(), + 
acc2_8x32 = _mm256_setzero_si256(), acc3_8x32 = _mm256_setzero_si256(); + for (k = 0; k < A->column; k += 32) { + __m256i aa = _mm256_loadu_si256((const __m256i_u *)&data_A[i * A->column + k]); + __m256i cc = _mm256_loadu_si256((const __m256i_u *)&data_A[(i + 1) * A->column + k]); + __m256i bb = _mm256_loadu_si256((const __m256i_u *)&data_B[j * B->row + k]); + __m256i dd = _mm256_loadu_si256((const __m256i_u *)&data_B[(j + 1) * B->row + k]); + + multiply_signed_int8_2x2_32epi(aa, bb, cc, dd, acc0_8x32, acc1_8x32, acc2_8x32, acc3_8x32); + } + int32_t *accptr = (int32_t *)&acc0_8x32; + acc0 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + accptr = (int32_t *)&acc1_8x32; + acc1 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + accptr = (int32_t *)&acc2_8x32; + acc2 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + accptr = (int32_t *)&acc3_8x32; + acc3 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + + data_C[i * C->column + j] = ((float)acc0 * effective_scale) + params->bias.data_ptr[j]; + data_C[i * C->column + j + 1] = ((float)acc1 * effective_scale) + params->bias.data_ptr[j + 1]; + data_C[(i + 1) * C->column + j] = ((float)acc2 * effective_scale) + params->bias.data_ptr[j]; + data_C[(i + 1) * C->column + j + 1] = ((float)acc3 * effective_scale) + params->bias.data_ptr[j + 1]; + } + return NULL; +} + +void MatmulOperator::mat_mul_avx_int8_fast_2x2_32unroll_bfp32_ofp32(const struct matmul_params *params) { + int j, num_thread = params->opt_params.num_thread; + + assert(params->A.column % 64 == 0); + assert((params->C.column) % 2 == 0); + + pthread_t thread_pool[num_thread]; + struct thread_args threads_args[num_thread]; + + // Thread creation + for (j = 0; j < num_thread; j++) { + threads_args[j].start_i = j * (params->C.row / num_thread); + threads_args[j].end_i = (j + 1) * (params->C.row / num_thread); + threads_args[j].blk_size = params->opt_params.blk_size; + threads_args[j].params = params; + pthread_create(&thread_pool[j], NULL, mat_mul_avx_int8_thread_func_2x2_32unroll_bfp32_ofp32, &threads_args[j]); + } + // Join threads + for (j = 0; j < num_thread; j++) { + pthread_join(thread_pool[j], NULL); + } +} + +void MatmulOperator::mat_mul_avx_int8_fast_2x2(const struct matmul_params *params) { + int j, num_thread = params->opt_params.num_thread; + + assert(params->A.column % 64 == 0); + assert((params->C.column) % 2 == 0); + + pthread_t thread_pool[num_thread]; + struct thread_args threads_args[num_thread]; + + // Thread creation + for (j = 0; j < num_thread; j++) { + threads_args[j].start_i = j * (params->C.row / num_thread); + threads_args[j].end_i = (j + 1) * (params->C.row / num_thread); + threads_args[j].blk_size = params->opt_params.blk_size; + threads_args[j].params = params; + pthread_create(&thread_pool[j], NULL, mat_mul_avx_int8_thread_func_2x2, &threads_args[j]); + } + // Join threads + for (j = 0; j < num_thread; j++) { + pthread_join(thread_pool[j], NULL); + } +} + +void MatmulOperator::mat_mul_avx_int8_fast_2x2_omp(const struct matmul_params *params) { + int i, j, k; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + int32_t A_zp = A->qparams.zero_point, C_zp = C->qparams.zero_point; + float A_sc = A->qparams.scale, B_sc = B->qparams.scale, C_sc = C->qparams.scale; + float effective_scale = A_sc * B_sc / C_sc; + int8_t *data_A = A->int8_data_ptr, *data_B = 
B->int8_data_ptr, *data_C = C->int8_data_ptr; + const int8_t q_min = C->qparams.q_min, q_max = C->qparams.q_max; + + assert(A->column % 64 == 0); + assert((C->column) % 2 == 0); + assert((C->row) % 2 == 0); + + // #pragma omp parallel for + for (i = 0; i < C->row; i += 2) + for (j = 0; j < C->column; j += 2) { + // (i, j), (i, j+1), (i+1, j), (i+1, j+1) + int acc0 = 0, acc1 = 0, acc2 = 0, acc3 = 0; + __m256i acc0_8x32 = _mm256_setzero_si256(), acc1_8x32 = _mm256_setzero_si256(), + acc2_8x32 = _mm256_setzero_si256(), acc3_8x32 = _mm256_setzero_si256(); + for (k = 0; k < A->column; k += 64) { + __m256i aa = _mm256_loadu_si256((const __m256i_u *)&data_A[i * A->column + k]), + aa2 = _mm256_loadu_si256((const __m256i_u *)(&data_A[i * A->column + k + 32])); + __m256i cc = _mm256_loadu_si256((const __m256i_u *)&data_A[(i + 1) * A->column + k]), + cc2 = _mm256_loadu_si256((const __m256i_u *)(&data_A[(i + 1) * A->column + k + 32])); + // assume B is transposed + __m256i bb = _mm256_loadu_si256((const __m256i_u *)&data_B[j * B->row + k]), + bb2 = _mm256_loadu_si256((const __m256i_u *)(&data_B[j * B->row + k + 32])); + __m256i dd = _mm256_loadu_si256((const __m256i_u *)&data_B[(j + 1) * B->row + k]), + dd2 = _mm256_loadu_si256((const __m256i_u *)(&data_B[(j + 1) * B->row + k + 32])); + + multiply_signed_int8_2x2(aa, bb, aa2, bb2, cc, cc2, dd, dd2, acc0_8x32, acc1_8x32, acc2_8x32, + acc3_8x32); + } + int32_t *accptr = (int32_t *)&acc0_8x32; + acc0 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + acc0 += params->bias.int32_data_ptr[j]; + accptr = (int32_t *)&acc1_8x32; + acc1 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + acc1 += params->bias.int32_data_ptr[j + 1]; + accptr = (int32_t *)&acc2_8x32; + acc2 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + acc2 += params->bias.int32_data_ptr[j]; + accptr = (int32_t *)&acc3_8x32; + acc3 = accptr[0] + accptr[1] + accptr[2] + accptr[3] + accptr[4] + accptr[5] + accptr[6] + accptr[7]; + acc3 += params->bias.int32_data_ptr[j + 1]; + + acc0 = (int32_t)std::round((float)acc0 * effective_scale); + acc1 = (int32_t)std::round((float)acc1 * effective_scale); + acc2 = (int32_t)std::round((float)acc2 * effective_scale); + acc3 = (int32_t)std::round((float)acc3 * effective_scale); + + acc0 -= C_zp; + acc1 -= C_zp; + acc2 -= C_zp; + acc3 -= C_zp; + + acc0 = MAX(acc0, q_min); + acc1 = MAX(acc1, q_min); + acc2 = MAX(acc2, q_min); + acc3 = MAX(acc3, q_min); + acc0 = MIN(acc0, q_max); + acc1 = MIN(acc1, q_max); + acc2 = MIN(acc2, q_max); + acc3 = MIN(acc3, q_max); + data_C[i * C->column + j] = (int8_t)acc0; + data_C[i * C->column + j + 1] = (int8_t)acc1; + data_C[(i + 1) * C->column + j] = (int8_t)acc2; + data_C[(i + 1) * C->column + j + 1] = (int8_t)acc3; + } +} + +} // namespace matmul + +// void initialize_vector(int8_t A[], int size) { +// for (int i = 0; i < size; i++) { +// // A[i] = (rand() % 2) - 1; +// A[i] = (rand() % 254) - 127; +// } +// } + +// int main(){ +// int8_t A[64], B[64]; +// initialize_vector(A, 64); +// initialize_vector(B, 64); + +// int32_t ref_acc = 0, acc; +// for (int i = 0; i < 64; i++){ +// ref_acc += A[i] * B[i]; +// } + +// __m256i aa = _mm256_loadu_si256((const __m256i_u *)A), bb = _mm256_loadu_si256((const __m256i_u *)B); +// __m256i aa2 = _mm256_loadu_si256((const __m256i_u *)(&A[32])), bb2 = _mm256_loadu_si256((const __m256i_u +// *)(&B[32])); + +// __m256i acc0_8x32 = 
multiply_signed_int8(aa, bb, aa2, bb2); +// int32_t *accptr = (int32_t*)&acc0_8x32; +// acc = accptr[0] + accptr[1] + accptr[2] + accptr[3]+ accptr[4] + accptr[5] + accptr[6] + accptr[7]; + +// printf("%d, %d\n", acc, ref_acc); +// assert(acc == ref_acc); + +// return 0; +// } diff --git a/experimental/matmul_optimization/src/lib/matmul_imp.cc b/experimental/matmul_optimization/src/lib/matmul_imp.cc new file mode 100644 index 00000000..6a5d4fc5 --- /dev/null +++ b/experimental/matmul_optimization/src/lib/matmul_imp.cc @@ -0,0 +1,446 @@ +#ifndef MATMUL_H_ +#define MATMUL_H_ + +#include +#include +#include +#include +// #include +#include // intel SSE intrinsic + +#include +#include + +#include "matmul.h" + +#define MAX_TRANSPOSE_BUFFER 2048 * 20480 +#define RUNS 1 + +float transpose_tmp[MAX_TRANSPOSE_BUFFER]; + +namespace matmul { + +void MatmulOperator::CHECK_MATRICES(const struct matrix *A, const struct matrix *B, const struct matrix *C) { + assert(A->column == B->row); + assert(C->column == B->column); + assert(C->row == A->row); +} + +inline void simd_mul_fp_128(const float *a, const float *b, float *c) { + __m128 val = _mm_mul_ps(_mm_load_ps(a), _mm_load_ps(b)); + __m128 acc = _mm_add_ps(_mm_load_ps(c), val); + _mm_store_ps(c, acc); +} + +void MatmulOperator::naive_mat_mul(const struct matmul_params *params) { + int i, j, k; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr; + CHECK_MATRICES(A, B, C); + + for (i = 0; i < C->row; i++) + for (j = 0; j < C->column; j++) { + float acc = 0; + for (k = 0; k < A->column; k++) acc += data_A[i * A->column + k] * data_B[k * B->column + j]; + data_C[i * C->column + j] = acc; + } +} + +void MatmulOperator::mat_mul_unrolling(const struct matmul_params *params) { + int i, j, k; + + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr; + CHECK_MATRICES(A, B, C); + + for (i = 0; i < C->row; i++) + for (j = 0; j < C->column; j += 8) { + float acc0 = 0; + float acc1 = 0; + float acc2 = 0; + float acc3 = 0; + float acc4 = 0; + float acc5 = 0; + float acc6 = 0; + float acc7 = 0; + for (k = 0; k < A->column; k += 4) { + float Aik0 = data_A[i * A->column + k]; + float Aik1 = data_A[i * A->column + k + 1]; + float Aik2 = data_A[i * A->column + k + 2]; + float Aik3 = data_A[i * A->column + k + 3]; + + acc0 += Aik0 * data_B[k * B->column + j]; + acc0 += Aik1 * data_B[(k + 1) * B->column + j]; + acc0 += Aik2 * data_B[(k + 2) * B->column + j]; + acc0 += Aik3 * data_B[(k + 3) * B->column + j]; + + acc1 += Aik0 * data_B[k * B->column + j + 1]; + acc1 += Aik1 * data_B[(k + 1) * B->column + j + 1]; + acc1 += Aik2 * data_B[(k + 2) * B->column + j + 1]; + acc1 += Aik3 * data_B[(k + 3) * B->column + j + 1]; + + acc2 += Aik0 * data_B[k * B->column + j + 2]; + acc2 += Aik1 * data_B[(k + 1) * B->column + j + 2]; + acc2 += Aik2 * data_B[(k + 2) * B->column + j + 2]; + acc2 += Aik3 * data_B[(k + 3) * B->column + j + 2]; + + acc3 += Aik0 * data_B[k * B->column + j + 3]; + acc3 += Aik1 * data_B[(k + 1) * B->column + j + 3]; + acc3 += Aik2 * data_B[(k + 2) * B->column + j + 3]; + acc3 += Aik3 * data_B[(k + 3) * B->column + j + 3]; + + acc4 += Aik0 * data_B[k * B->column + j + 4]; + acc4 += Aik1 * data_B[(k + 1) * B->column + j + 4]; + acc4 += Aik2 * data_B[(k + 2) * B->column + j + 4]; + acc4 += Aik3 * data_B[(k + 3) * B->column + j + 4]; + + acc5 += Aik0 * data_B[k * B->column + j + 5]; + acc5 += Aik1 * 
data_B[(k + 1) * B->column + j + 5]; + acc5 += Aik2 * data_B[(k + 2) * B->column + j + 5]; + acc5 += Aik3 * data_B[(k + 3) * B->column + j + 5]; + + acc6 += Aik0 * data_B[k * B->column + j + 6]; + acc6 += Aik1 * data_B[(k + 1) * B->column + j + 6]; + acc6 += Aik2 * data_B[(k + 2) * B->column + j + 6]; + acc6 += Aik3 * data_B[(k + 3) * B->column + j + 6]; + + acc7 += Aik0 * data_B[k * B->column + j + 7]; + acc7 += Aik1 * data_B[(k + 1) * B->column + j + 7]; + acc7 += Aik2 * data_B[(k + 2) * B->column + j + 7]; + acc7 += Aik3 * data_B[(k + 3) * B->column + j + 7]; + } + data_C[i * C->column + j] = acc0; + data_C[i * C->column + j + 1] = acc1; + data_C[i * C->column + j + 2] = acc2; + data_C[i * C->column + j + 3] = acc3; + data_C[i * C->column + j + 4] = acc4; + data_C[i * C->column + j + 5] = acc5; + data_C[i * C->column + j + 6] = acc6; + data_C[i * C->column + j + 7] = acc7; + } +} + +void MatmulOperator::mat_mul_reordering(const struct matmul_params *params) { + int i, j, k; + float Aik; + + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr; + CHECK_MATRICES(A, B, C); + + for (i = 0; i < C->row; i++) + for (j = 0; j < C->column; j++) data_C[i * C->column + j] = 0; + + for (i = 0; i < C->row; i++) + for (k = 0; k < A->column; k++) { + Aik = data_A[i * A->column + k]; + for (j = 0; j < C->column; j++) { + data_C[i * C->column + j] += Aik * data_B[k * B->column + j]; + } + } +} + +void MatmulOperator::mat_mul_tiling(const struct matmul_params *params) { + int i, j, k, ti, tj, tk; + float Aik; + + int BLK_SIZE = params->opt_params.blk_size; + + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr; + CHECK_MATRICES(A, B, C); + assert(C->row % BLK_SIZE == 0); + assert(A->column % BLK_SIZE == 0); + assert(C->column % BLK_SIZE == 0); + + for (i = 0; i < C->row; i++) + for (j = 0; j < C->column; j++) data_C[i * C->column + j] = 0; + + for (ti = 0; ti < C->row; ti += BLK_SIZE) { + for (tk = 0; tk < A->column; tk += BLK_SIZE) { + for (tj = 0; tj < C->column; tj += BLK_SIZE) { + for (i = ti; i < ti + BLK_SIZE; i++) + for (k = tk; k < tk + BLK_SIZE; k++) { + Aik = data_A[i * A->column + k]; + for (j = tj; j < tj + BLK_SIZE; j++) { + data_C[i * C->column + j] += Aik * data_B[k * B->column + j]; + } + } + } + } + } +} + +/* This function assume legal matrices */ +void *thread_func(void *args) { + struct thread_args *mat_args = (struct thread_args *)args; + const struct matrix *A = mat_args->A; + const struct matrix *B = mat_args->B; + const struct matrix *C = mat_args->C; + float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr; + int start_i = mat_args->start_i, end_i = mat_args->end_i; + + for (int i = start_i; i < end_i; i++) + for (int j = 0; j < C->column; j++) { + float acc = 0; + for (int k = 0; k < A->column; k++) acc += data_A[i * A->column + k] * data_B[k * B->column + j]; + data_C[i * C->column + j] = acc; + } + + return NULL; +} + +void MatmulOperator::mat_mul_multithreading(const struct matmul_params *params) { + int j, num_thread = params->opt_params.num_thread; + + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + CHECK_MATRICES(A, B, C); + assert(num_thread != 0); + assert(C->row % num_thread == 0); + + pthread_t thread_pool[num_thread]; + struct thread_args threads_args[num_thread]; + + // Thread creation + for (j = 0; j < num_thread; j++) { + threads_args[j].start_i = j * (C->row / 
num_thread); + threads_args[j].end_i = (j + 1) * (C->row / num_thread); + threads_args[j].A = A; + threads_args[j].B = B; + threads_args[j].C = C; + pthread_create(&thread_pool[j], NULL, thread_func, &threads_args[j]); + } + // Join threads + for (j = 0; j < num_thread; j++) { + pthread_join(thread_pool[j], NULL); + } +} + +void MatmulOperator::mat_mul_transpose(const struct matmul_params *params) { + int i, j, k; + + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr; + CHECK_MATRICES(A, B, C); + + // transpose the B + for (i = 0; i < B->column; i++) + for (j = 0; j < B->row; j++) transpose_tmp[i * B->row + j] = data_B[j * B->column + i]; + + for (i = 0; i < C->row; i++) + for (j = 0; j < C->column; j++) { + float acc = 0; + for (k = 0; k < A->column; k++) acc += data_A[i * A->column + k] * transpose_tmp[j * B->row + k]; + data_C[i * C->column + j] = acc; + } +} + +void MatmulOperator::mat_mul_transpose_simd(const struct matmul_params *params) { + int i, j, k; + + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr; + CHECK_MATRICES(A, B, C); + + // transpose the B + for (i = 0; i < B->column; i++) + for (j = 0; j < B->row; j++) transpose_tmp[i * B->row + j] = data_B[j * B->column + i]; + + for (i = 0; i < C->row; i++) + for (j = 0; j < C->column; j++) { + float accumulators[4] = {}; + for (k = 0; k < A->column; k += 4) + simd_mul_fp_128(&data_A[i * A->column + k], &transpose_tmp[j * B->row + k], accumulators); + data_C[i * C->column + j] = accumulators[0] + accumulators[1] + accumulators[2] + accumulators[3]; + } +} + +float interval_to_ms(struct timeval *start, struct timeval *end) { + float us_seconds = (end->tv_sec - start->tv_sec) * 1000000 + (end->tv_usec - start->tv_usec); + return us_seconds / 1000; +} + +void MatmulOperator::evaluate(IMP_TYPE type, const struct matmul_params *params) { + struct timeval start, end; + int ms; + std::string function_name; + + gettimeofday(&start, NULL); + // choose implementation + switch (type) { + case NAIVE: + function_name = "naive_mat_mul"; + for (int i = 0; i < RUNS; i++) this->naive_mat_mul(params); + break; + case UNROLL: + function_name = "mat_mul_unrolling"; + for (int i = 0; i < RUNS; i++) this->mat_mul_unrolling(params); + break; + case REORDER: + function_name = "mat_mul_reordering"; + for (int i = 0; i < RUNS; i++) this->mat_mul_reordering(params); + break; + case TILING: + function_name = "mat_mul_tiling"; + for (int i = 0; i < RUNS; i++) this->mat_mul_tiling(params); + break; + case MULTITHREAD: + function_name = "mat_mul_multithreading"; + for (int i = 0; i < RUNS; i++) this->mat_mul_multithreading(params); + break; + case TRANSPOSE: + function_name = "mat_mul_transpose"; + for (int i = 0; i < RUNS; i++) this->mat_mul_transpose(params); + break; + case TRANSPOSE_SIMD: + function_name = "mat_mul_transpose_simd"; + for (int i = 0; i < RUNS; i++) this->mat_mul_transpose_simd(params); + break; + case CUDA: + function_name = "mat_mul_cuda"; +#ifdef CUDA_ENABLE + for (int i = 0; i < RUNS; i++) this->mat_mul_cuda(params); +#else + fprintf(stderr, "CUDA not enable!\n"); + exit(-1); +#endif + break; + case FAST: + function_name = "mat_mul_fast"; + for (int i = 0; i < RUNS; i++) this->mat_mul_fast(params); + break; + case ONEDNN_FP32: + function_name = "mat_mul_onednn"; +#ifdef ONEDNN_ENABLE + for (int i = 0; i < RUNS; i++) this->mat_mul_onednn(params); +#else + fprintf(stderr, 
"ONEDNN not enable!\n"); + exit(-1); +#endif + break; + case ONEDNN_INT8: + function_name = "mat_mul_onednn_int8"; +#ifdef ONEDNN_ENABLE + for (int i = 0; i < RUNS; i++) this->mat_mul_onednn_int8(params); +#else + fprintf(stderr, "ONEDNN not enable!\n"); + exit(-1); +#endif + break; + case INT8_BASELINE: + function_name = "naive_mat_mul_int8"; + for (int i = 0; i < RUNS; i++) this->naive_mat_mul_int8(params); + break; + case INT8_AVX: + function_name = "mat_mul_avx_int8"; + for (int i = 0; i < RUNS; i++) this->mat_mul_avx_int8(params); + break; + case INT8_AVX_FAST: + function_name = "mat_mul_avx_int8_fast"; + for (int i = 0; i < RUNS; i++) this->mat_mul_avx_int8_fast(params); + break; + case INT8_AVX_FAST_2x2: + function_name = "mat_mul_avx_int8_fast_2x2"; + for (int i = 0; i < RUNS; i++) this->mat_mul_avx_int8_fast_2x2(params); + break; + case INT8_AVX_FAST_2x2_32UNROLL: + function_name = "mat_mul_avx_int8_fast_2x2_32unroll"; + for (int i = 0; i < RUNS; i++) this->mat_mul_avx_int8_fast_2x2_32unroll(params); + break; + case INT8_AVX_FAST_2x2_OMP: + function_name = "mat_mul_avx_int8_fast_2x2_omp"; + for (int i = 0; i < RUNS; i++) this->mat_mul_avx_int8_fast_2x2_omp(params); + break; + default: + break; + } + gettimeofday(&end, NULL); + ms = interval_to_ms(&start, &end); + float GOPS = + (float)((float)params->C.column * (float)params->C.row * (float)params->B.row) * 2 / (1000000000) * RUNS; + std::cout << function_name << ": " << ms << " ms, GOPS/s:" << GOPS / ((float)ms / 1000) << std::endl; +} + +void *fast_thread_func(void *args) { + struct thread_args *mat_args = (struct thread_args *)args; + const struct matrix *A = mat_args->A; + const struct matrix *B = mat_args->B; + const struct matrix *C = mat_args->C; + float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr; + int start_i = mat_args->start_i, end_i = mat_args->end_i; + + int BLK_SIZE = mat_args->blk_size; + assert((end_i - start_i) % BLK_SIZE == 0); + assert(A->column % BLK_SIZE == 0); + assert(C->column % BLK_SIZE == 0); + assert(BLK_SIZE % 4 == 0); + + for (int ti = start_i; ti < end_i; ti += BLK_SIZE) { + for (int tj = 0; tj < C->column; tj += BLK_SIZE) { + for (int i = ti; i < ti + BLK_SIZE; i++) + for (int j = tj; j < tj + BLK_SIZE; j += 4) { + float acc0[4] = {}, acc1[4] = {}, acc2[4] = {}, acc3[4] = {}; + __m128 *acc0_fp_128 = (__m128 *)acc0; + __m128 *acc1_fp_128 = (__m128 *)acc1; + __m128 *acc2_fp_128 = (__m128 *)acc2; + __m128 *acc3_fp_128 = (__m128 *)acc3; + + for (int k = 0; k < A->column; k += 4) { + __m128 Aik_Aik3 = _mm_load_ps(&data_A[i * A->column + k]); + __m128 val; + val = _mm_mul_ps(Aik_Aik3, _mm_load_ps(&data_B[j * B->column + k])); + *acc0_fp_128 = _mm_add_ps(*acc0_fp_128, val); + + val = _mm_mul_ps(Aik_Aik3, _mm_load_ps(&data_B[(j + 1) * B->column + k])); + *acc1_fp_128 = _mm_add_ps(*acc1_fp_128, val); + + val = _mm_mul_ps(Aik_Aik3, _mm_load_ps(&data_B[(j + 2) * B->column + k])); + *acc2_fp_128 = _mm_add_ps(*acc2_fp_128, val); + + val = _mm_mul_ps(Aik_Aik3, _mm_load_ps(&data_B[(j + 3) * B->column + k])); + *acc3_fp_128 = _mm_add_ps(*acc3_fp_128, val); + } + data_C[i * C->column + j] = acc0[0] + acc0[1] + acc0[2] + acc0[3]; + data_C[i * C->column + j + 1] = acc1[0] + acc1[1] + acc1[2] + acc1[3]; + data_C[i * C->column + j + 2] = acc2[0] + acc2[1] + acc2[2] + acc2[3]; + data_C[i * C->column + j + 3] = acc3[0] + acc3[1] + acc3[2] + acc3[3]; + } + } + } + + return NULL; +} + +void MatmulOperator::mat_mul_fast(const struct matmul_params *params) { + int j, num_thread = 
params->opt_params.num_thread; + + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + + assert(A->column == B->column); + assert(C->column == B->row); + assert(C->row == A->row); + assert(num_thread != 0); + assert(C->row % num_thread == 0); + + pthread_t thread_pool[num_thread]; + struct thread_args threads_args[num_thread]; + + // Thread creation + for (j = 0; j < num_thread; j++) { + threads_args[j].start_i = j * (C->row / num_thread); + threads_args[j].end_i = (j + 1) * (C->row / num_thread); + threads_args[j].blk_size = params->opt_params.blk_size; + threads_args[j].A = A; + threads_args[j].B = B; + threads_args[j].C = C; + pthread_create(&thread_pool[j], NULL, fast_thread_func, &threads_args[j]); + } + // Join threads + for (j = 0; j < num_thread; j++) { + pthread_join(thread_pool[j], NULL); + } +} + +} // namespace matmul + +#endif // MATMUL_H_ diff --git a/experimental/matmul_optimization/src/lib/matmul_int8.cc b/experimental/matmul_optimization/src/lib/matmul_int8.cc new file mode 100644 index 00000000..3d5f2ed3 --- /dev/null +++ b/experimental/matmul_optimization/src/lib/matmul_int8.cc @@ -0,0 +1,31 @@ +#include +#include + +#include "matmul.h" + +namespace matmul { + +void MatmulOperator::naive_mat_mul_int8(const struct matmul_params *params) { + int i, j, k; + const struct matrix *A = ¶ms->A, *B = ¶ms->B, *C = ¶ms->C; + int32_t A_zp = A->qparams.zero_point, C_zp = C->qparams.zero_point; + float A_sc = A->qparams.scale, B_sc = B->qparams.scale, C_sc = C->qparams.scale; + float effective_scale = A_sc * B_sc / C_sc; + int8_t *data_A = A->int8_data_ptr, *data_B = B->int8_data_ptr, *data_C = C->int8_data_ptr; + const int8_t q_min = C->qparams.q_min, q_max = C->qparams.q_max; + CHECK_MATRICES(A, B, C); + + for (i = 0; i < C->row; i++) + for (j = 0; j < C->column; j++) { + int acc = 0; + for (k = 0; k < A->column; k++) + acc += ((int32_t)data_A[i * A->column + k] - A_zp) * data_B[k * B->column + j]; + + acc = (int32_t)((float)acc * effective_scale); + acc -= C_zp; + acc = MAX(acc, q_min); + acc = MIN(acc, q_max); + data_C[i * C->column + j] = (int8_t)acc; + } +} +} // namespace matmul diff --git a/experimental/matmul_optimization/src/lib/matmul_onednn.cc b/experimental/matmul_optimization/src/lib/matmul_onednn.cc new file mode 100644 index 00000000..444b5b83 --- /dev/null +++ b/experimental/matmul_optimization/src/lib/matmul_onednn.cc @@ -0,0 +1,145 @@ +#ifdef ONEDNN_ENABLE +//#define DUMP_KERNEL_TIME +#include +#include + +#include "matmul.h" +#include "oneapi/dnnl/dnnl.hpp" + +namespace matmul { +// void assign_data() + +inline void write_to_dnnl_memory(void *handle, dnnl::memory &mem) { + dnnl::engine eng = mem.get_engine(); + size_t size = mem.get_desc().get_size(); + + if (!handle) throw std::runtime_error("handle is nullptr."); + + if (eng.get_kind() == dnnl::engine::kind::cpu) { + uint8_t *dst = static_cast(mem.get_data_handle()); + if (!dst) throw std::runtime_error("get_data_handle returned nullptr."); + for (size_t i = 0; i < size; ++i) dst[i] = ((uint8_t *)handle)[i]; + return; + } + + assert(!"not expected"); +} + +inline void read_from_dnnl_memory(void *handle, dnnl::memory &mem) { + dnnl::engine eng = mem.get_engine(); + size_t size = mem.get_desc().get_size(); + + if (eng.get_kind() == dnnl::engine::kind::cpu) { + uint8_t *src = static_cast(mem.get_data_handle()); + if (!src) throw std::runtime_error("get_data_handle returned nullptr."); + for (size_t i = 0; i < size; ++i) ((uint8_t *)handle)[i] = src[i]; + return; + } +} + +void 
MatmulOperator::mat_mul_onednn(const struct matmul_params *params) {
+    const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
+    float *data_A = A->data_ptr, *data_B = B->data_ptr, *data_C = C->data_ptr;
+    CHECK_MATRICES(A, B, C);
+
+    int M = A->row, N = B->column, K = A->column;
+
+    // Initialize description for matmul
+    dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::f32, {K, 1});  // M x K layout
+    dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::f32, {N, 1});  // K X N layout
+    dnnl::memory::desc c_md({M, N}, dnnl::memory::data_type::f32, {N, 1});  // M x N layout
+
+    dnnl::engine eng(dnnl::engine::kind::cpu, 0);
+    auto matmul_desc = dnnl::matmul::primitive_desc(eng, a_md, b_md, c_md);
+
+    // Initialize memory
+    dnnl::memory A_fp_mem(matmul_desc.src_desc(), eng, (void *)data_A);
+    dnnl::memory B_fp_mem(matmul_desc.weights_desc(), eng, (void *)data_B);
+    dnnl::memory C_fp_mem(matmul_desc.dst_desc(), eng, (void *)data_C);
+
+    // Operator
+    dnnl::matmul matmul_p(matmul_desc);
+
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    dnnl::stream s(eng);
+    matmul_p.execute(s, {{DNNL_ARG_SRC, A_fp_mem}, {DNNL_ARG_WEIGHTS, B_fp_mem}, {DNNL_ARG_DST, C_fp_mem}});
+
+    s.wait();
+    gettimeofday(&end, NULL);
+    int us = interval_to_us(&start, &end);
+#ifdef INT8_AVX_FAST
+    std::cout << "onednn kernel: " << us / 1000 << " ms" << std::endl;
+#endif
+}
+
+void MatmulOperator::mat_mul_onednn_int8(const struct matmul_params *params) {
+    const struct matrix *A = &params->A, *B = &params->B, *C = &params->C;
+    int8_t *data_A = A->int8_data_ptr, *data_B = B->int8_data_ptr, *data_C = C->int8_data_ptr;
+    CHECK_MATRICES(A, B, C);
+
+    int M = A->row, N = B->column, K = A->column;
+
+    // Initialize description for matmul
+    dnnl::memory::desc a_md({M, K}, dnnl::memory::data_type::s8, {K, 1});  // M x K layout
+    dnnl::memory::desc b_md({K, N}, dnnl::memory::data_type::s8, {N, 1});  // K X N layout
+    dnnl::memory::desc c_md({M, N}, dnnl::memory::data_type::s8, {N, 1});  // M x N layout
+
+    dnnl::engine eng(dnnl::engine::kind::cpu, 0);
+    dnnl::primitive_attr attr;
+    attr.set_scales_mask(DNNL_ARG_SRC, /* mask */ 0);
+    attr.set_scales_mask(DNNL_ARG_WEIGHTS, /* mask */ 1 << 1);
+    attr.set_scales_mask(DNNL_ARG_DST, /* mask */ 0);
+    attr.set_zero_points_mask(DNNL_ARG_SRC, /* mask */ 0);
+    attr.set_zero_points_mask(DNNL_ARG_DST, /* mask */ 0);
+    auto matmul_desc = dnnl::matmul::primitive_desc(eng, a_md, b_md, c_md, attr);
+
+    // Initialize memory
+    dnnl::memory A_mem(matmul_desc.src_desc(), eng, (void *)data_A);
+    dnnl::memory B_mem(matmul_desc.weights_desc(), eng, (void *)data_B);
+    dnnl::memory C_mem(matmul_desc.dst_desc(), eng, (void *)data_C);
+
+    dnnl::memory zp_A_mem({{1}, dnnl::memory::data_type::s32, {1}}, eng);
+    dnnl::memory zp_C_mem({{1}, dnnl::memory::data_type::s32, {1}}, eng);
+    dnnl::memory sc_A_mem({{1}, dnnl::memory::data_type::f32, {1}}, eng);
+    dnnl::memory sc_B_mem({{N}, dnnl::memory::data_type::f32, {1}}, eng);
+    dnnl::memory sc_C_mem({{1}, dnnl::memory::data_type::f32, {1}}, eng);
+
+    // Assign zero points
+    int32_t *zp_handle = static_cast<int32_t *>(zp_A_mem.get_data_handle());
+    *zp_handle = A->qparams.zero_point;
+    zp_handle = static_cast<int32_t *>(zp_C_mem.get_data_handle());
+    *zp_handle = C->qparams.zero_point;
+
+    // Assign scales
+    float *sc_handle = static_cast<float *>(sc_A_mem.get_data_handle());
+    *sc_handle = A->qparams.scale;
+    sc_handle = static_cast<float *>(sc_B_mem.get_data_handle());
+    for (int i = 0; i < N; i++) sc_handle[i] = B->qparams.scale;
+    sc_handle = static_cast<float *>(sc_C_mem.get_data_handle());
*sc_handle = C->qparams.scale;
+
+    // Operator
+    dnnl::matmul matmul_p(matmul_desc);
+
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    dnnl::stream s(eng);
+    matmul_p.execute(s, {{DNNL_ARG_SRC, A_mem},
+                         {DNNL_ARG_WEIGHTS, B_mem},
+                         {DNNL_ARG_DST, C_mem},
+                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC, sc_A_mem},
+                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_WEIGHTS, sc_B_mem},
+                         {DNNL_ARG_ATTR_SCALES | DNNL_ARG_DST, sc_C_mem},
+                         {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC, zp_A_mem},
+                         {DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_DST, zp_C_mem}});
+    s.wait();
+    gettimeofday(&end, NULL);
+    int us = interval_to_us(&start, &end);
+#ifdef INT8_AVX_FAST
+    std::cout << "onednn kernel: " << us / 1000 << " ms" << std::endl;
+#endif
+}
+
+} // namespace matmul
+#endif
diff --git a/experimental/matmul_optimization/src/lib/utils.cc b/experimental/matmul_optimization/src/lib/utils.cc
new file mode 100644
index 00000000..74513d43
--- /dev/null
+++ b/experimental/matmul_optimization/src/lib/utils.cc
@@ -0,0 +1,12 @@
+#include <sys/time.h>
+
+#include "matmul.h"
+
+namespace matmul {
+
+float MatmulOperator::interval_to_us(struct timeval *start, struct timeval *end) {
+    float us_seconds = (end->tv_sec - start->tv_sec) * 1000000 + (end->tv_usec - start->tv_usec);
+    return us_seconds;
+}
+
+} // namespace matmul
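
Note on the reduction pattern in the 2x2 AVX int8 kernels above (illustration only, not part of the patch): each kernel collapses its four __m256i accumulators to scalars by casting to int32_t * and summing the eight lanes by hand. The same reduction can be written with AVX2 intrinsics; the sketch below uses a hypothetical helper name, hsum_epi32_avx2, which does not exist in this codebase.

    #include <immintrin.h>
    #include <stdint.h>

    // Sum the eight int32 lanes of v; equivalent to the
    // accptr[0] + ... + accptr[7] pattern used after
    // multiply_signed_int8_2x2_32epi in the kernels above.
    static inline int32_t hsum_epi32_avx2(__m256i v) {
        __m128i lo = _mm256_castsi256_si128(v);       // lanes 0..3
        __m128i hi = _mm256_extracti128_si256(v, 1);  // lanes 4..7
        __m128i s = _mm_add_epi32(lo, hi);            // four pairwise sums
        s = _mm_add_epi32(s, _mm_shuffle_epi32(s, _MM_SHUFFLE(1, 0, 3, 2)));
        s = _mm_add_epi32(s, _mm_shuffle_epi32(s, _MM_SHUFFLE(2, 3, 0, 1)));
        return _mm_cvtsi128_si32(s);                  // every lane now holds the total
    }

Used as acc0 = hsum_epi32_avx2(acc0_8x32);, it should produce the same int32 result as the manual eight-term sum before the round / zero-point / clamp requantization step.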