Skip to content

Commit

Permalink
Support StarCoder on CPU (#78)
Browse files Browse the repository at this point in the history
  • Loading branch information
RaymondWang0 authored Nov 22, 2023
1 parent ab90cc4 commit c3b94c7
Show file tree
Hide file tree
Showing 29 changed files with 2,310 additions and 11 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@

assets/
*.bin
!llama_vocab.bin
!starcoder_vocab.bin
*.zip
*.txt
!requirements.txt
Expand Down
2 changes: 1 addition & 1 deletion kernels/avx/matmul_avx_int8_int4.cc
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_
}
threads_args[j].params = params;
// pthread_create(&thread_pool[j], NULL, fast_int8_int4_zp_no_offset_over_column_unroll2block, &threads_args[j]);
pool_enqueue(pool, &threads_args[j], NULL);
pool_enqueue(pool, &threads_args[j], '\0');
}
// // Join threads
// for (j = 0; j < num_thread; j++) pthread_join(thread_pool[j], NULL);
Expand Down
8 changes: 5 additions & 3 deletions kernels/neon/matmul_neon_int8_int4.cc
Original file line number Diff line number Diff line change
Expand Up @@ -357,12 +357,14 @@ inline static void* matmul_int8_int4_no_offset_over_column_unroll128(void* args)
sumv2 = vmlaq_n_f32(sumv2, vcvtq_f32_s32(int_sum2), s_2);
sumv3 = vmlaq_n_f32(sumv3, vcvtq_f32_s32(int_sum3), s_3);
}
if (params->bias.data_ptr)
if (params->bias.data_ptr) {
params->C.data_ptr[i * n + j] = params->bias.data_ptr[j] + vaddvq_f32(sumv0) + vaddvq_f32(sumv1) +
vaddvq_f32(sumv2) + vaddvq_f32(sumv3);
else
}
else {
params->C.data_ptr[i * n + j] =
vaddvq_f32(sumv0) + vaddvq_f32(sumv1) + vaddvq_f32(sumv2) + vaddvq_f32(sumv3);
}
}
}

Expand Down Expand Up @@ -586,7 +588,7 @@ void MatmulOperator::mat_mul_accelerator_int8_int4_fast_no_offset(struct matmul_
// #else
// pthread_create(&thread_pool[j], NULL, matmul_int8_int4_no_offset_over_column_unroll128, &threads_args[j]);
// #endif
pool_enqueue(pool, &threads_args[j], NULL);
pool_enqueue(pool, &threads_args[j], '\0');
}
// Join threads
// for (j = 0; j < num_thread; j++) pthread_join(thread_pool[j], NULL);
Expand Down
78 changes: 75 additions & 3 deletions llm/application/chat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@
std::map<std::string, int> model_config = {
{"OPT_125m", OPT_125M}, {"OPT_1.3B", OPT_1_3B}, {"OPT_6.7B", OPT_6_7B}, {"LLaMA_7B", LLaMA_7B},
{"LLaMA2_7B_chat", LLaMA_7B}, {"7b", LLaMA_7B}, {"LLaMA2_13B_chat", LLaMA_13B}, {"13b", LLaMA_13B},
{"CodeLLaMA_7B_Instruct", CodeLLaMA_7B}, {"CodeLLaMA_13B_Instruct", CodeLLaMA_13B}};
{"CodeLLaMA_7B_Instruct", CodeLLaMA_7B}, {"CodeLLaMA_13B_Instruct", CodeLLaMA_13B},
{"StarCoder", StarCoder_15_5B}, {"StarCoder_15.5B", StarCoder_15_5B}
};

std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"},
{"OPT_1.3B", "models/OPT_1.3B"},
Expand All @@ -18,7 +20,10 @@ std::map<std::string, std::string> model_path = {{"OPT_125m", "models/OPT_125m"}
{"7b", "models/LLaMA_7B_2_chat"},
{"13b", "models/LLaMA_13B_2_chat"},
{"CodeLLaMA_7B_Instruct", "models/CodeLLaMA_7B_Instruct"},
{"CodeLLaMA_13B_Instruct", "models/CodeLLaMA_13B_Instruct"}};
{"CodeLLaMA_13B_Instruct", "models/CodeLLaMA_13B_Instruct"},
{"StarCoder", "models/StarCoder"},
{"StarCoder_15.5B", "models/StarCoder"}
};

std::map<std::string, int> data_format_list = {
{"FP32", FP32}, {"INT8", QINT8}, {"INT4", INT4}, {"int4", INT4}, {"fp32", FP32},
Expand All @@ -43,6 +48,15 @@ bool isCodeLLaMA(std::string s) {
return false;
}

bool isStarCoder(std::string s) {
std::string StarCoder_prefix = "StarCoder";

if (s.substr(0, StarCoder_prefix.size()) == StarCoder_prefix)
return true;
else
return false;
}

bool convertToBool(const char* str) {
if (strcmp(str, "true") == 0 || strcmp(str, "1") == 0) {
return true;
Expand Down Expand Up @@ -124,7 +138,15 @@ int main(int argc, char* argv[]) {
std::cout << "Using AWQ for 4bit quantization: https://github.com/mit-han-lab/llm-awq" << std::endl;
else
std::cout << "Using data format: " << target_data_format << std::endl;
} else { // OPT
}
else if (isStarCoder(target_model)) {
std::cout << "Using model: " + target_model << std::endl;
if (target_data_format == "INT4" || target_data_format == "int4")
std::cout << "Using AWQ for 4bit quantization: https://github.com/mit-han-lab/llm-awq" << std::endl;
else
std::cout << "Using data format: " << target_data_format << std::endl;
}
else { // OPT
target_model = "OPT6.7B";
target_data_format = "INT8";
std::cout << "Using model: " + target_model << std::endl;
Expand Down Expand Up @@ -241,6 +263,56 @@ int main(int argc, char* argv[]) {
std::cout << std::endl;
std::cerr << "At this time, we only support FP32 and INT4 for LLaMA7B." << std::endl;
}
} else if (isStarCoder(target_model)) {
int format_id = data_format_list[target_data_format];

// Load model
std::cout << "Loading model... " << std::flush;
int model_id = model_config[target_model];
std::string m_path = model_path[target_model];

#ifdef MODEL_PREFIX
m_path = MODEL_PREFIX + m_path;
#endif

struct opt_params generation_config;
generation_config.n_predict = 128;
// generation_config.repeat_penalty = 1.1f;
generation_config.top_k = 0;
generation_config.temp = 0.2f;
generation_config.n_vocab = 49152;

if (format_id == FP32) {
Fp32GPTBigCodeForCausalLM model = Fp32GPTBigCodeForCausalLM(m_path, get_opt_model_config(model_id));
std::cout << "Finished!" << std::endl;

// Get input from the user
while (true) {
std::cout << "USER: ";
std::string input;
std::getline(std::cin, input);
std::cout << input;

GPTBigCodeGenerate(m_path, &model, StarCoder_FP32, input, generation_config, "models/starcoder_vocab.bin", true, false);
}
} else if (format_id == INT4) {
m_path = "INT4/" + m_path;
Int4GPTBigCodeForCausalLM model = Int4GPTBigCodeForCausalLM(m_path, get_opt_model_config(model_id));
std::cout << "Finished!" << std::endl;

// Get input from the user
while (true) {
std::cout << "USER: ";
std::string input;
std::getline(std::cin, input);
std::cout << input;

GPTBigCodeGenerate(m_path, &model, StarCoder_INT4, input, generation_config, "models/starcoder_vocab.bin", true, false);
}
} else {
std::cout << std::endl;
std::cerr << "At this time, we only support FP32 and INT4 for StarCoder." << std::endl;
}
} else { // OPT
#ifdef QM_CUDA
printf("OPT is not supported with CUDA backend yet.");
Expand Down
52 changes: 52 additions & 0 deletions llm/include/GPTBigCodeTokenizer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
Adapted from llama.cpp and starcoder.cpp:
https://github.com/ggerganov/llama.cpp
https://github.com/bigcode-project/starcoder.cpp
*/

#ifndef GPTBIGCODE_TOKENIZER_H
#define GPTBIGCODE_TOKENIZER_H

#include <cstdint>
#include <cstdio>
#include <iostream>
#include <map>
#include <queue>
#include <string>
#include <unordered_map>
#include <vector>
#include <random>
#include <thread>
#include <fstream>

//
// Vocab utils
//

std::string trim(const std::string & s);

std::string replace(
const std::string & s,
const std::string & from,
const std::string & to);

struct starcoder_vocab {
std::map<std::string, int32_t> token_to_id;
std::map<int32_t, std::string> id_to_token;
std::vector<std::string> special_tokens;

void add_special_token(const std::string & token);
};

/*
* Tokenizer
*/
starcoder_vocab starcoder_init_vocab(const std::string & vocab_file);

const char* starcoder_id_to_token(starcoder_vocab& vocab, int id);

int starcoder_tokenize(const starcoder_vocab &vocab, const std::string &text, std::vector<int> &final_tokens, int n_max_tokens);

#endif
7 changes: 6 additions & 1 deletion llm/include/Generate.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,10 @@ Adapted from llama.cpp:

#include "Fp32OPTForCausalLM.h"
#include "Fp32llamaForCausalLM.h"
#include "Fp32GPTBigCodeForCausalLM.h"
#include "Int4OPTForCausalLM.h"
#include "Int4llamaForCausalLM.h"
#include "Int4GPTBigCodeForCausalLM.h"
#include "OPTForCausalLM.h"
#include "OPTTokenizer.h"
#include "operators.h"
Expand Down Expand Up @@ -98,8 +100,11 @@ std::vector<int> OPTGenerate(void* model, int model_type, std::vector<int> input
const struct opt_params generation_config, Encoder* encoder = NULL,
bool interactive = false, bool voicechat = false);

enum { OPT_INT8, LLaMA_FP32, LLaMA_INT4, OPT_FP32, OPT_INT4 };
enum { OPT_INT8, LLaMA_FP32, LLaMA_INT4, OPT_FP32, OPT_INT4, StarCoder_FP32, StarCoder_INT4 };
std::string LLaMAGenerate(std::string param_path, void* model, int model_type, std::string text, const struct opt_params generation_config,
std::string voc_path, bool interactive, bool voicechat);

std::string GPTBigCodeGenerate(std::string param_path, void *model_ptr, int model_type, std::string text, const struct opt_params generation_config,
std::string voc_path, bool interactive, bool voicechat);

#endif // GENERATE_H
7 changes: 6 additions & 1 deletion llm/include/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct model_config {
rms_norm_eps(rms_norm_eps) {}
};

enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B, CodeLLaMA_7B, CodeLLaMA_13B };
enum { OPT_125M, OPT_1_3B, OPT_6_7B, LLaMA_7B, LLaMA_13B, CodeLLaMA_7B, CodeLLaMA_13B, StarCoder_15_5B };
enum { FP32, QINT8, INT4 };

const struct model_config opt_6_7B(1, 32, 32, 2048, 4096, 16384, 50272, 1, 0);
Expand All @@ -37,6 +37,8 @@ const struct model_config llama_7B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 1e-6)
const struct model_config llama_13B(1, 40, 40, 2048, 5120, 13824, 32000, 1, 1e-6);
const struct model_config codellama_7B(1, 32, 32, 2048, 4096, 11008, 32016, 1, 1e-5);
const struct model_config codellama_13B(1, 40, 40, 2048, 5120, 13824, 32016, 1, 1e-5);
// const struct model_config starcoder_15_5B(1, 32, 32, 2048, 4096, 11008, 32000, 1, 0); // temporary
const struct model_config starcoder_15_5B(1, 48, 40, 2048, 6144, 24576, 49152, 1, 0);
static struct model_config get_opt_model_config(int choise) {
struct model_config ret;
switch (choise) {
Expand All @@ -61,6 +63,9 @@ static struct model_config get_opt_model_config(int choise) {
case CodeLLaMA_13B:
ret = codellama_13B;
break;
case StarCoder_15_5B:
ret = starcoder_15_5B;
break;
default:
throw("Unsupported model choice.");
break;
Expand Down
47 changes: 47 additions & 0 deletions llm/include/nn_modules/Fp32GPTBigCodeAttention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
#include <utility>

#include "common.h"
#include "operators.h"

struct Fp32GPTBigCodeAttention_output {
Matrix3D<float> attn_output;
Matrix3D<float> attn_probs_reshaped;
std::pair<Matrix3D<float>, Matrix3D<float>> past_key_value;
};
struct Fp32GPTBigCodeAttention_input {
Matrix3D<float> hidden_states;
Matrix3D<float> attention_mask;
Matrix3D<float> past_key, past_value;
bool has_past_key_value = false;
int layer_idx;

Fp32GPTBigCodeAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, int layer_idx_)
: hidden_states(hidden_states_), attention_mask(attention_mask_), layer_idx(layer_idx_) {}

Fp32GPTBigCodeAttention_input(Matrix3D<float> hidden_states_, Matrix3D<float> attention_mask_, Matrix3D<float> past_key_,
Matrix3D<float> past_value_, bool has_past_key_value_, int layer_idx_)
: hidden_states(hidden_states_),
attention_mask(attention_mask_),
past_key(past_key_),
past_value(past_value_),
has_past_key_value(has_past_key_value_),
layer_idx(layer_idx_) {}
};

class Fp32GPTBigCodeAttention {
public:
Fp32GPTBigCodeAttention(std::string param_path, const struct model_config config);
Fp32GPTBigCodeAttention() {}
static void initialized_memory(const struct model_config config);
struct Fp32GPTBigCodeAttention_output forward(const struct Fp32GPTBigCodeAttention_input &input);

private:
void unshape(Matrix3D<float> shaped, Matrix3D<float> unshape, int sqlen);
void shape_qkv(Matrix3D<float> unshape, Matrix3D<float> shaped_q, Matrix3D<float> shaped_k,
Matrix3D<float> shaped_v, int sqlen);
float scaling;
int embed_dim, num_heads, head_dim, kv_heads, kv_dim;
BMM_F32T qk_bmm, pv_bmm;
Linear_FP c_attn, c_proj;
std::string profile_name = "Fp32GPTBigCodeAttention";
};
44 changes: 44 additions & 0 deletions llm/include/nn_modules/Fp32GPTBigCodeDecoder.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
#include <cstdlib>
#include <string>
#include <vector>

#include "Fp32GPTBigCodeDecoderLayer.h"
#include "common.h"
#include "operators.h"

struct Fp32GPTBigCodeDecoder_output {
Matrix3D<float> last_hidden_state;
std::vector<Matrix3D<float>> past_keys, past_values;
};
struct Fp32GPTBigCodeDecoder_input {
Matrix3D<int> input_ids;
std::vector<Matrix3D<float>> past_keys, past_values;
bool has_past_keys_values;

Fp32GPTBigCodeDecoder_input(Matrix3D<int> input_ids_) : input_ids(input_ids_) { has_past_keys_values = false; }
Fp32GPTBigCodeDecoder_input(Matrix3D<int> input_ids_, std::vector<Matrix3D<float>> past_keys_,
std::vector<Matrix3D<float>> past_values_)
: input_ids(input_ids_), past_keys(past_keys_), past_values(past_values_) {
has_past_keys_values = true;
}
};

class Fp32GPTBigCodeDecoder {
public:
Fp32GPTBigCodeDecoder(std::string param_path, const struct model_config config);
Fp32GPTBigCodeDecoder(){};
Matrix3D<float> prepare_decoder_attention_mask(int length, int past_length);
Matrix3D<float> get_position_embed(int sql_length, int past_length);
struct Fp32GPTBigCodeDecoder_output forward(const struct Fp32GPTBigCodeDecoder_input& input);
Embedding wte, wpe;
int voc_size, embed_dim, padding_idx, hidden_dim, num_heads, max_position_embeddings;
std::vector<Fp32GPTBigCodeDecoderLayer> layers;
LayerNorm ln_f;
std::string profile_name = "Fp32GPTBigCodeDecoder";

private:
float* attention_mask_buf;
float* pos_embeds_buf;
float* last_hidden_states_buf;
float* hidden_states_buf;
};
Loading

0 comments on commit c3b94c7

Please sign in to comment.