More gpu memory saving for llama #20

Open · wants to merge 24 commits into base: corvo
5 changes: 3 additions & 2 deletions .vscode/settings.json
@@ -67,6 +67,7 @@
"unordered_set": "cpp",
"future": "cpp",
"cfenv": "cpp",
"typeindex": "cpp"
"typeindex": "cpp",
"__config": "cpp"
}
}
}
7 changes: 6 additions & 1 deletion examples/cpp/llama/llama_example.cc
@@ -463,7 +463,12 @@ void llama_example(const INIReader reader)
cudaD2Hcpy(seqlBuf, d_sequence_lengths, seqLCount);
cudaD2Hcpy(inlBuf, d_sequence_lengths, seqLCount);
printf("seqlBuf: %d\n", seqlBuf[0]);

/*
golden request:
1, 18637, 29892, 526, 366, 1136, 455, 2470, 29973, 1815, 366, 5193, 304, 592, 29973
golden result:
1 18637 29892 526 366 1136 455 2470 29973 1815 366 5193 304 592 29973 31489 25709 29251 25143 9777 24957 12623 29013 25302 11973 886 29457 6626 13638 10893 26609 25049 15066 29013 1927 27436 28754 1740 698 24551 25482 31552 22617 1140 293 10146 912
*/
{
std::cout << "Writing " << outCount << " elements\n";
int zeroCount = 0;
35 changes: 23 additions & 12 deletions src/fastertransformer/kernels/unfused_attention_kernels.cu
@@ -1334,6 +1334,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
const int batch_size,
const int seq_len,
const int head_num,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
const bool neox_rotary_style,
@@ -1395,18 +1396,22 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*

const int prefix_prompt_length = PREFIX_PROMPT ? param.d_prefix_prompt_lengths[batch_idx] : 0;
const int hidden_idx = head_idx * size_per_head + tidx * vec_size;
const int n = head_num * size_per_head;
const int kv_repeat_num = head_num / kv_head_num;
const int kv_hidden_idx = head_idx / kv_repeat_num * size_per_head + tidx * vec_size;
const int qkv_size = head_num * size_per_head + 2 * kv_head_num * size_per_head;
const int k_offset = head_num * size_per_head;
const int v_offset = k_offset + kv_head_num * size_per_head;

// the [0..seq_len) indices really handle KV [max_pp_len..seq_len+max_pp_len)
// and Q [0..seq_len)
// Note: if !PREFIX_PROMPT, max_pp_len = 0, so it's no-op
const int dst_kv_seq_idx = seq_idx + prefix_prompt_length;

// NOTE: q has seq len excluding prefix prompt
// src QKV: [batch, time, 3, head, hidden]
const int src_q_idx = token_idx * 3 * n + hidden_idx;
const int src_k_idx = token_idx * 3 * n + hidden_idx + n;
const int src_v_idx = token_idx * 3 * n + hidden_idx + 2 * n;
// src QKV: [batch, time, head+2*kv_head, hidden]
const int src_q_idx = token_idx * qkv_size + hidden_idx;
const int src_k_idx = token_idx * qkv_size + kv_hidden_idx + k_offset;
const int src_v_idx = token_idx * qkv_size + kv_hidden_idx + v_offset;

Vec_t q, k, v;
Vec_t q_bias, k_bias, v_bias;
@@ -1415,14 +1420,17 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
k = *reinterpret_cast<const Vec_t*>(&QKV[src_k_idx]);
v = *reinterpret_cast<const Vec_t*>(&QKV[src_v_idx]);

q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + n]);
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx + 2 * n]);
if (qkv_bias) {
q_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[hidden_idx]);
k_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[kv_hidden_idx + k_offset]);
v_bias = *reinterpret_cast<const Vec_t*>(&qkv_bias[kv_hidden_idx + v_offset]);
}
}
if (qkv_bias) {
q = mmha::add(q, q_bias);
k = mmha::add(k, k_bias);
v = mmha::add(v, v_bias);
}

q = mmha::add(q, q_bias);
k = mmha::add(k, k_bias);
v = mmha::add(v, v_bias);

if (!neox_rotary_style) {
mmha::apply_rotary_embedding(q, k, tidx, rotary_embedding_dim, rope_theta, dst_kv_seq_idx);
@@ -1495,6 +1503,7 @@ __global__ void add_fusedQKV_bias_transpose_kernel(T*
batch_size, \
seq_len, \
head_num, \
kv_head_num, \
size_per_head, \
rotary_embedding_dim, \
neox_rotary_style, \
@@ -1512,6 +1521,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int seq_len,
const int token_num,
const int head_num,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
const int neox_rotary_style,
@@ -1571,6 +1581,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int seq_len, \
const int token_num, \
const int head_num, \
const int kv_head_num, \
const int size_per_head, \
const int rotary_embedding_dim, \
const int neox_rotary_style, \
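Reviewer note: to make the new indexing in this kernel easier to follow, here is a minimal standalone sketch of the index math now used for grouped-query attention (GQA). It is not the kernel itself, just the arithmetic from the diff, and it assumes the fused QKV layout [batch, time, head_num + 2 * kv_head_num, size_per_head] with head_num divisible by kv_head_num:

// Sketch only, not part of the PR. One KV head serves head_num / kv_head_num query heads,
// so K/V indices collapse onto kv_head_num groups while Q keeps one slot per query head.
struct QkvSrcIdx { int q, k, v; };

QkvSrcIdx gqa_src_indices(int token_idx, int head_idx, int elem_idx,
                          int head_num, int kv_head_num, int size_per_head)
{
    const int kv_repeat_num = head_num / kv_head_num;                                // query heads per KV head
    const int hidden_idx    = head_idx * size_per_head + elem_idx;                   // Q slot for this head
    const int kv_hidden_idx = (head_idx / kv_repeat_num) * size_per_head + elem_idx; // shared KV slot
    const int qkv_size      = (head_num + 2 * kv_head_num) * size_per_head;          // fused row width
    const int k_offset      = head_num * size_per_head;
    const int v_offset      = k_offset + kv_head_num * size_per_head;
    return {token_idx * qkv_size + hidden_idx,
            token_idx * qkv_size + kv_hidden_idx + k_offset,
            token_idx * qkv_size + kv_hidden_idx + v_offset};
}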
3 changes: 3 additions & 0 deletions src/fastertransformer/kernels/unfused_attention_kernels.h
@@ -138,6 +138,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
seq_len,
token_num,
head_num,
head_num,
size_per_head,
0,
false,
@@ -177,6 +178,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
seq_len,
token_num,
head_num,
head_num,
size_per_head,
rotary_embedding_dim,
neox_rotary_style,
@@ -198,6 +200,7 @@ void invokeAddFusedQKVBiasTranspose(T* q_buf,
const int seq_len,
const int token_num,
const int head_num,
const int kv_head_num,
const int size_per_head,
const int rotary_embedding_dim,
const int neox_rotary_style,
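In this header, the existing convenience overloads now pass head_num for the new kv_head_num argument, which keeps plain multi-head attention unchanged. A quick standalone check (my own sketch, not PR code) that the new index math degenerates to the old 3 * n layout in that case:

#include <cassert>

int main()
{
    const int head_num = 8, size_per_head = 64;
    const int kv_head_num = head_num;          // MHA fallback used by the overloads above
    const int token_idx = 2, head_idx = 3, elem_idx = 5;

    const int n             = head_num * size_per_head;
    const int qkv_size      = (head_num + 2 * kv_head_num) * size_per_head;
    const int kv_repeat_num = head_num / kv_head_num;   // == 1
    const int hidden_idx    = head_idx * size_per_head + elem_idx;
    const int kv_hidden_idx = (head_idx / kv_repeat_num) * size_per_head + elem_idx;  // == hidden_idx

    assert(qkv_size == 3 * n);                                                             // same row width
    assert(token_idx * qkv_size + kv_hidden_idx + n == token_idx * 3 * n + hidden_idx + n);          // old K index
    assert(token_idx * qkv_size + kv_hidden_idx + 2 * n == token_idx * 3 * n + hidden_idx + 2 * n);  // old V index
    return 0;
}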
@@ -138,17 +138,17 @@ void LlamaContextAttentionLayer<T>::forward(TensorMap* output_ten
local_qkv_size, // n
attention_input,
hidden_units_, // k
qkv_buf_tmp_,
qkv_buf_,
local_qkv_size /* n */);
if (local_kv_head_num_ != local_head_num_) {
invokeRepeatKv(qkv_buf_,
qkv_buf_tmp_,
local_head_num_,
local_kv_head_num_,
size_per_head_,
m,
stream_);
}
// if (local_kv_head_num_ != local_head_num_) {
// invokeRepeatKv(qkv_buf_,
// qkv_buf_tmp_,
// local_head_num_,
// local_kv_head_num_,
// size_per_head_,
// m,
// stream_);
// }

// {
// const int head_num = 6;
@@ -327,12 +327,13 @@ void LlamaContextAttentionLayer<T>::forward(TensorMap* output_ten
v_buf_2_,
param, // prefix prompt
qkv_buf_,
attention_weights->query_weight.bias,
(T*)(nullptr),
padding_offset,
request_batch_size,
request_seq_len,
m,
local_head_num_,
local_kv_head_num_,
size_per_head_,
rotary_embedding_dim_,
neox_rotary_style_,
@@ -729,13 +730,8 @@ void LlamaContextAttentionLayer<T>::allocateBuffer(size_t batch_size, size_t seq
// const auto type_size = int8_mode_ == 2 ? sizeof(int8_t) : sizeof(T);
// NOTE (perkzz): use sizeof(T) here for cutlass int8 kernels.
const auto type_size = sizeof(T);
qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, type_size * 3 * batch_size * seq_len * local_hidden_units_, true);
if (local_kv_head_num_ != local_head_num_) {
size_t local_qkv_size = local_hidden_units_ + 2 * local_kv_head_num_ * size_per_head_;
qkv_buf_tmp_ = (T*)allocator_->reMalloc(qkv_buf_tmp_, type_size * batch_size * seq_len * local_qkv_size, true);
} else {
qkv_buf_tmp_ = qkv_buf_;
}
size_t local_qkv_size = local_hidden_units_ + 2 * local_kv_head_num_ * size_per_head_;
qkv_buf_ = (T*)allocator_->reMalloc(qkv_buf_, type_size * batch_size * seq_len * local_qkv_size, true);
q_buf_2_ = (T*)allocator_->reMalloc(q_buf_2_, sizeof(T) * batch_size * seq_len * 3 * local_hidden_units_, true);
k_buf_2_ = q_buf_2_ + batch_size * seq_len * local_hidden_units_;
v_buf_2_ = k_buf_2_ + batch_size * seq_len * local_hidden_units_;
@@ -789,9 +785,6 @@ void LlamaContextAttentionLayer<T>::freeBuffer()
if (is_allocate_buffer_) {
FT_LOG_DEBUG(__PRETTY_FUNCTION__);
allocator_->free((void**)(&qkv_buf_));
if (local_kv_head_num_ != local_head_num_) {
allocator_->free((void**)(&qkv_buf_tmp_));
}
allocator_->free((void**)(&q_buf_2_));
allocator_->free((void**)(&qk_buf_));
allocator_->free((void**)(&qkv_buf_2_));
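The allocateBuffer() change above is where the GPU memory saving comes from: qkv_buf_ shrinks from 3 * local_hidden_units_ to local_hidden_units_ + 2 * local_kv_head_num_ * size_per_head_ elements per token, and the separate qkv_buf_tmp_ is dropped. A rough sketch of that arithmetic (reviewer illustration, assuming local_hidden_units_ == local_head_num_ * size_per_head_; the example shapes below are illustrative, not from this PR):

#include <cstddef>
#include <cstdio>

// Elements held by qkv_buf_ after this PR; the old size was batch * seq_len * 3 * head_num * size_per_head.
std::size_t qkv_buf_elems(std::size_t batch, std::size_t seq_len,
                          std::size_t head_num, std::size_t kv_head_num, std::size_t size_per_head)
{
    const std::size_t local_qkv_size = head_num * size_per_head + 2 * kv_head_num * size_per_head;
    return batch * seq_len * local_qkv_size;
}

int main()
{
    // Example: 64 query heads, 8 KV heads, 128-dim heads (LLaMA-2-70B-like GQA shapes).
    const std::size_t old_elems = 1 * 2048 * 3 * 64 * 128;
    const std::size_t new_elems = qkv_buf_elems(1, 2048, 64, 8, 128);
    std::printf("old: %zu, new: %zu (%.0f%% of old)\n",
                old_elems, new_elems, 100.0 * new_elems / old_elems);
    return 0;
}

With those shapes the fused QKV activation drops to roughly 42% of its previous size; with kv_head_num equal to head_num the two sizes coincide and nothing changes for MHA models.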
@@ -65,7 +65,6 @@ class LlamaContextAttentionLayer: public BaseAttentionLayer<T> {
using BaseAttentionLayer<T>::stream_;
using BaseAttentionLayer<T>::sparse_;
T* qkv_buf_ = nullptr;
T* qkv_buf_tmp_ = nullptr;
T* q_buf_2_ = nullptr;
T* k_buf_2_ = nullptr;
T* v_buf_2_ = nullptr;
@@ -144,7 +144,7 @@ void LlamaDecoderLayerWeight<T>::copyFrom(const LlamaDecoderLayerWeight& other)
cudaD2Dcpy(weight_only_scale_ptr[1], other.weight_only_scale_ptr[1], hidden_units_);
cudaD2Dcpy(weight_only_scale_ptr[2], other.weight_only_scale_ptr[2], inter_size_ / tensor_para_size_);

// TODO: it is unclear whether the scale factor stored here corresponds to gate_pro_weight or is used by up_proj/down_proj; verify this later and fix it
// TODO: not sure gate_pro_weight corresponds to up_proj or down_proj
cudaD2Dcpy(weight_only_scale_ptr[3], other.weight_only_scale_ptr[3], inter_size_ / tensor_para_size_);
cudaD2Dcpy(weight_only_scale_ptr[4], other.weight_only_scale_ptr[4], hidden_units_);
}