Skip to content

Commit

Permalink
refactor kvcache manager and rotary_embedding and kvcache_memcpy oper…
Browse files Browse the repository at this point in the history
…ator
  • Loading branch information
SunflowerAries committed Apr 26, 2024
1 parent 5be590b commit 5630324
Show file tree
Hide file tree
Showing 10 changed files with 175 additions and 159 deletions.
23 changes: 17 additions & 6 deletions colossalai/inference/kv_cache/kvcache_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,18 @@ def __init__(self, config: InferenceConfig, model_config: PretrainedConfig, verb
self.num_blocks = self.max_blocks_per_sequence * self.max_batch_size * self.beam_width

# Physical cache allocation
alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
self._kv_caches = self._init_device_caches(alloc_shape)
if config.use_cuda_kernel:
x = 16 // torch.tensor([], dtype=config.dtype).element_size()
kalloc_shape = (self.num_blocks, self.kv_head_num, self.head_size // x, self.block_size, x)
valloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
self.logger.info(
f"Allocating K cache with shape: {kalloc_shape}, V cache with shape: {valloc_shape} consisting of {self.num_blocks} blocks."
)
self._kv_caches = self._init_device_caches(kalloc_shape, valloc_shape)
else:
alloc_shape = (self.num_blocks, self.kv_head_num, self.block_size, self.head_size)
self.logger.info(f"Allocating KV cache with shape: {alloc_shape} consisting of {self.num_blocks} blocks.")
self._kv_caches = self._init_device_caches(alloc_shape, alloc_shape)
self.total_physical_cache_size_in_bytes = (
self.elem_size_in_bytes
* self.num_layers
Expand Down Expand Up @@ -479,7 +488,9 @@ def _init_logical_caches(self):
blocks.append(cache_block)
return blocks

def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tensor, torch.Tensor]:
def _init_device_caches(
self, kalloc_shape: Tuple[int, ...], valloc_shape: Tuple[int, ...]
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Initialize the physical cache on the device.
For each layer of the model, we allocate two tensors for key and value respectively,
Expand All @@ -488,6 +499,6 @@ def _init_device_caches(self, alloc_shape: Tuple[int, ...]) -> Tuple[torch.Tenso
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for _ in range(self.num_layers):
k_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
v_cache.append(torch.zeros(alloc_shape, dtype=self.dtype, device=self.device))
k_cache.append(torch.zeros(kalloc_shape, dtype=self.dtype, device=self.device))
v_cache.append(torch.zeros(valloc_shape, dtype=self.dtype, device=self.device))
return k_cache, v_cache
1 change: 1 addition & 0 deletions colossalai/inference/modeling/models/nopadding_baichuan.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,7 @@ def forward(
alibi_slopes=self.alibi_slopes,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
use_new_kcache_layout=use_cuda_kernel,
)
else:
q_len = tokens_to_verify + 1 if is_verifier else 1
Expand Down
59 changes: 30 additions & 29 deletions colossalai/inference/modeling/models/nopadding_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -575,6 +575,7 @@ def forward(
output=output_tensor,
max_seq_len=kv_seq_len,
sm_scale=sm_scale,
use_new_kcache_layout=use_cuda_kernel,
)
else:
q_len = tokens_to_verify + 1 if is_verifier else 1
Expand All @@ -592,20 +593,20 @@ def forward(
block_tables,
high_precision,
)
# inference_ops.flash_decoding_attention(
# output_tensor,
# query_states,
# k_cache,
# v_cache,
# sequence_lengths,
# block_tables,
# block_size,
# kv_seq_len,
# fd_inter_tensor.mid_output,
# fd_inter_tensor.mid_output_lse,
# sm_scale,
# )
# attn_output = output_tensor
inference_ops.flash_decoding_attention(
output_tensor,
query_states,
k_cache,
v_cache,
sequence_lengths,
block_tables,
block_size,
kv_seq_len,
fd_inter_tensor.mid_output,
fd_inter_tensor.mid_output_lse,
sm_scale,
)
attn_output = output_tensor
else:
if is_verifier:
rotary_embedding(query_states, key_states, cos_sin[0], cos_sin[1])
Expand All @@ -627,21 +628,21 @@ def forward(
block_tables,
sequence_lengths,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
kv_group_num=self.num_key_value_groups,
q_len=q_len,
)
attn_output = flash_decoding_attention(
q=query_states,
k_cache=k_cache,
v_cache=v_cache,
kv_seq_len=sequence_lengths,
block_tables=block_tables,
block_size=block_size,
max_seq_len_in_batch=kv_seq_len,
output=output_tensor,
mid_output=fd_inter_tensor.mid_output,
mid_output_lse=fd_inter_tensor.mid_output_lse,
sm_scale=sm_scale,
kv_group_num=self.num_key_value_groups,
q_len=q_len,
)

attn_output = attn_output.view(-1, self.hidden_size)
attn_output = self.o_proj(attn_output)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@

from colossalai.kernel.kernel_loader import InferenceOpsLoader
from colossalai.kernel.triton import copy_kv_to_blocked_cache, decoding_fused_rotary_embedding, rotary_embedding
from tests.test_infer.test_ops.triton.kernel_utils import mock_alloc_block_table_and_kvcache_v2, mock_alloc_single_token
from tests.test_infer.test_ops.triton.kernel_utils import (
mock_alloc_block_table_and_kvcache_v2,
mock_alloc_block_table_and_kvcache_v3,
mock_alloc_single_token,
)

inference_ops = InferenceOpsLoader().load()

Expand Down Expand Up @@ -68,11 +72,17 @@ def benchmark_rotary_emb(
cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, block_size, head_dim)
k_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
v_cache = torch.zeros(size=cache_shape, dtype=dtype, device="cuda")
x = 16 // torch.tensor([], dtype=dtype).element_size()
new_cache_shape = (BATCH_SIZE * max_num_blocks_per_seq, num_kv_heads, head_dim // x, block_size, x)
new_k_cache = torch.zeros(size=new_cache_shape, dtype=dtype, device="cuda")

past_kv_seq_lengths = torch.tensor([SEQ_LEN - 1 for _ in range(BATCH_SIZE)], dtype=torch.int32, device="cuda")
block_tables = mock_alloc_block_table_and_kvcache_v2(
k, v, k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
)
_ = mock_alloc_block_table_and_kvcache_v3(
k, v, new_k_cache, v_cache, past_kv_seq_lengths, BATCH_SIZE, max_num_blocks_per_seq, block_size
)
new_k = torch.randn((BATCH_SIZE, num_kv_heads, head_dim), dtype=dtype, device="cuda")
new_q = torch.randn_like(new_k)
new_v = torch.randn_like(new_k)
Expand All @@ -94,12 +104,12 @@ def benchmark_rotary_emb(
)
elif provider == "no_fused_cuda_rotary_emb_func":
fn = lambda: [
inference_ops.rotary_embedding(new_q, new_k, cos, sin),
inference_ops.rotary_embedding(new_q, new_k, cos, sin, True),
inference_ops.decode_kv_cache_memcpy(new_k, new_v, k_cache, v_cache, kv_seq_lengths, block_tables),
]
elif provider == "fused_cuda_rotary_emb_func":
fn = lambda: inference_ops.rotary_embedding_and_cache_copy(
new_q, new_k, new_v, cos, sin, k_cache, v_cache, kv_seq_lengths, block_tables
new_q, new_k, new_v, cos, sin, new_k_cache, v_cache, kv_seq_lengths, block_tables, True
)
else:
raise ValueError("Undefined provider")
Expand Down
46 changes: 33 additions & 13 deletions extensions/csrc/kernel/cuda/context_kv_cache_memcpy_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,15 @@ __global__ void context_kv_cache_memcpy_kernel(
const int batch_size,
const int block_table_stride,
const int64_t key_stride,
const int64_t value_stride
const int64_t value_stride,
const int x
)
{
const int seq_token_id = blockIdx.x;
const int seq_id = blockIdx.y;
const int block_id = block_tables[seq_id * block_table_stride + seq_token_id / block_size];

if ( block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
if (block_id < 0 || seq_token_id > sequence_lengths[seq_id] - 1) {
return ;
}

Expand All @@ -39,38 +40,55 @@ __global__ void context_kv_cache_memcpy_kernel(
const int total_token_id = cu_seqlens[seq_id] + seq_token_id;
int head_id;
int head_offset;
int x_id;
int x_offset;
int64_t key_src_id;
int64_t value_src_id;
int64_t target_id;
int64_t target_key_id;
int64_t target_value_id;

int i = threadIdx.x * VecSize;

for (; i <= (hidden_size - VecSize); i += blockDim.x * VecSize) {
head_id = i / head_dim;
head_offset = i % head_dim;
x_id = head_offset / x;
x_offset = head_offset % x;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
target_key_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ x_id * block_size * x
+ block_offset * x
+ x_offset;
target_value_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;

copy_vector<scalar_t, VecSize>(key_cache + target_id, key + key_src_id);
copy_vector<scalar_t, VecSize>(value_cache + target_id, value + value_src_id);
copy_vector<scalar_t, VecSize>(key_cache + target_key_id, key + key_src_id);
copy_vector<scalar_t, VecSize>(value_cache + target_value_id, value + value_src_id);
}

// tail process
if (!Aligned) {
for (; i < hidden_size; ++i ) {
head_id = i / head_dim;
head_offset = i % head_dim;
x_id = head_offset / x;
x_offset = head_offset % x;
key_src_id = total_token_id * key_stride + i;
value_src_id = total_token_id * value_stride + i;
target_id = block_id * hidden_size * block_size
target_key_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ x_id * block_size * x
+ block_offset * x
+ x_offset;
target_value_id = block_id * hidden_size * block_size
+ head_id * block_size * head_dim
+ block_offset * head_dim + head_offset;

key_cache[target_id] = key[key_src_id];
value_cache[target_id] = value[value_src_id];
key_cache[target_key_id] = key[key_src_id];
value_cache[target_value_id] = value[value_src_id];
}
}

Expand All @@ -80,7 +98,7 @@ template<typename scalar_t>
void apply_context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
Expand All @@ -90,7 +108,8 @@ void apply_context_kv_cache_memcpy(
int num_tokens = key.size(0);
int head_num = key.size(1);
int head_dim = key.size(2);
int block_size = key_cache.size(2);
int block_size = key_cache.size(3);
int x = key_cache.size(4);
int batch_size = block_tables.size(0);

int64_t key_stride = key.stride(0);
Expand Down Expand Up @@ -126,7 +145,8 @@ void apply_context_kv_cache_memcpy(
batch_size, \
block_table_stride, \
key_stride, \
value_stride \
value_stride, \
x \
); \
} while(0)

Expand Down Expand Up @@ -163,7 +183,7 @@ void apply_context_kv_cache_memcpy(
void context_kv_cache_memcpy(
at::Tensor& key, // [num_tokens, head_num, head_dim]
at::Tensor& value, // [num_tokens, head_num, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& key_cache, // [num_blocks, head_num, head_dim/x, block_size, x]
at::Tensor& value_cache, // [num_blocks, head_num, block_size, head_dim]
at::Tensor& sequence_lengths, // [batch_size]
at::Tensor& cu_seqlens, // [batch_size + 1]
Expand Down
Loading

0 comments on commit 5630324

Please sign in to comment.