[KVCache] Support fork in sliding window sink part
This PR adds support for forking a sequence within the sink part of sliding window attention.
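In short, the new rule is: a sequence with sliding window enabled can be forked as long as the fork position falls inside its attention-sink prefix. Below is a minimal, self-contained Python sketch of that constraint (not the TVM implementation; the function name and the numbers are illustrative), mirroring the sink-size computation and the CHECK_LE added in paged_kv_cache.cc:

def can_fork(seq_length: int, last_block_seq_length: int,
             last_block_attn_sink_size: int, fork_pos: int) -> bool:
    """Model of the fork check added in this diff: the fork position must lie
    within the parent's attention-sink part."""
    sink_size = seq_length - last_block_seq_length + last_block_attn_sink_size
    return fork_pos <= sink_size

# Example: the parent holds 28 tokens, 10 of them in its last block, of which
# 7 are sink tokens, so the sink prefix spans 28 - 10 + 7 = 25 tokens.
assert can_fork(28, 10, 7, 18)      # fork inside the sink: allowed
assert not can_fork(28, 10, 7, 26)  # fork past the sink: rejected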
cyx-6 committed Jul 1, 2024
1 parent 4a5e22e commit 9198e9c
Showing 2 changed files with 90 additions and 30 deletions.
23 changes: 20 additions & 3 deletions src/runtime/relax_vm/paged_kv_cache.cc
@@ -1184,9 +1184,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
<< "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache.";
CHECK(seq_map_.find(child_seq_id) == seq_map_.end())
<< "The child sequence \"" << child_seq_id << "\" is already in the KV cache.";
CHECK_EQ(parent_it->second.sliding_window_size, -1)
<< "The parent sequence \"" << parent_seq_id
<< "\" is enabled with sliding window and thus cannot be forked.";
CHECK_GE(fork_pos, -1)
<< "The forked position should be non-negative, or -1 for last position as default.";
CHECK_LE(fork_pos, parent_it->second.seq_length)
@@ -1199,6 +1196,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
fork_pos = parent_it->second.seq_length;
}

if (parent_it->second.sliding_window_size != -1) {
// If the parent sequence has sliding window enabled, check that the forked position is
// within the sliding window sink size.
const Sequence& seq = parent_it->second;
int32_t sink_size = seq.seq_length - global_block_pool_[seq.last_block_idx].seq_length +
seq.last_block_attn_sink_size;
CHECK_LE(fork_pos, sink_size)
<< "The parent sequence \"" << parent_seq_id
<< "\" is enabled with sliding window and thus only can be forked within sink size = "
<< sink_size << ". But the forked position = " << fork_pos << ".";
}

if (fork_pos == parent_it->second.seq_length && fork_pos % page_size_ == 0 &&
global_block_pool_[parent_it->second.last_block_idx].seq_length > 0) {
// To enable the parent sequence to continue decode after the fork,
@@ -1258,6 +1267,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// Update in-block sequence length per blocks
global_block_pool_[parent_block_idx].seq_length = moved_offset;
global_block_pool_[forked_block_idx].seq_length -= moved_offset;

// Update sliding window sink size if sliding window is enabled and the forked block is the
// last block
if (parent_it->second.sliding_window_size != -1 &&
forked_block_idx == parent_it->second.last_block_idx) {
CHECK_LE(moved_offset, parent_it->second.last_block_attn_sink_size);
parent_it->second.last_block_attn_sink_size -= moved_offset;
}
}
global_block_pool_[child_block_idx].start_pos = fork_pos - in_page_offset;
global_block_pool_[child_block_idx].seq_length = in_page_offset;
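The second hunk above also keeps the per-block sink bookkeeping consistent when the fork splits the parent's last block: the tokens moved out of that block were all sink tokens (the fork position is inside the sink), so the block's recorded sink size shrinks accordingly. A rough, self-contained Python model of that adjustment (illustrative only, not the TVM code):

def split_last_block_sink(last_block_attn_sink_size: int, moved_offset: int) -> int:
    # Mirrors the guard and update in the diff: all moved tokens must be sink
    # tokens, and the last block keeps the remaining sink count.
    assert moved_offset <= last_block_attn_sink_size
    return last_block_attn_sink_size - moved_offset

# Example: the last block tracked 7 sink tokens and 4 of them are moved out
# during the fork, leaving 3 sink tokens attributed to that block.
assert split_last_block_sink(7, 4) == 3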
@@ -468,8 +468,11 @@ def apply_attention(

for seq_id, _ in batch:
if sliding_window_sizes is not None and len(sliding_window_sizes) > seq_id:
assert len(sliding_window_sizes) > seq_id and len(attn_sink_sizes) > seq_id
sliding_window_size = sliding_window_sizes[seq_id]
attn_sink_size = attn_sink_sizes[seq_id]
if sliding_window_size == 0:
continue
if cached_k[seq_id].shape[1] > sliding_window_size:
# Apply sliding window and sink to cached kv.
length_to_slide = cached_k[seq_id].shape[1] - sliding_window_size
@@ -746,34 +749,74 @@ def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
attn_sink_sizes,
)

# Sliding window with fork
sliding_window_sizes += [0, 18]
attn_sink_sizes += [0, 12]
apply_attention(kv_cache, rope_mode, [(5, 10)], cached_k, cached_v)
ffork_sequence(kv_cache, 5, 6, -1)
cached_k[6] = cached_k[5]
cached_v[6] = cached_v[5]

@tvm.testing.requires_gpu
@tvm.testing.requires_cuda
def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config):
kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
if not support_sliding_window or rope_mode == RopeMode.NORMAL:
return
fclear(kv_cache)

cached_k = {}
cached_v = {}
sliding_window_sizes = [30, 35, 40]
attn_sink_sizes = [15, 20, 25]
for seq_id, (sliding_window_size, attn_sink_size) in enumerate(
zip(sliding_window_sizes, attn_sink_sizes)
):
fadd_sequence(kv_cache, seq_id)
fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, attn_sink_size)
cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
apply_attention(
kv_cache,
rope_mode,
[(0, 12), (1, 18), (2, 28)],
cached_k,
cached_v,
sliding_window_sizes,
attn_sink_sizes,
)
# seq_len: [12, 18, 25+3]
sliding_window_sizes += [0, 0, 0]
attn_sink_sizes += [0, 0, 0]
apply_attention(
kv_cache,
rope_mode,
[((3, 0, 10), 8), ((4, 1, -1), 20), ((5, 2, 18), 18)],
cached_k,
cached_v,
sliding_window_sizes,
attn_sink_sizes,
)
# seq_len: [12, 18, 25+3, 18, 38, 36]
apply_attention(
kv_cache,
rope_mode,
[(0, 9), (1, 15), (2, 4), (3, 10), (4, 3), (5, 7)],
cached_k,
cached_v,
sliding_window_sizes,
attn_sink_sizes,
)
# seq_len: [15+6, 20+13, 25+7, 28, 41, 43]
sliding_window_sizes += [25]
attn_sink_sizes += [24]
ffork_sequence(kv_cache, 3, 6, 18)
fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1], attn_sink_sizes[-1])
for _ in range(2):
apply_attention(
kv_cache,
rope_mode,
[(6, 10)],
cached_k,
cached_v,
sliding_window_sizes,
attn_sink_sizes,
)
for _ in range(16):
apply_attention(
kv_cache,
rope_mode,
[(6, 1)],
cached_k,
cached_v,
sliding_window_sizes,
attn_sink_sizes,
)
cached_k[6] = cached_k[3][::, :18]
cached_v[6] = cached_v[3][::, :18]
apply_attention(
kv_cache,
rope_mode,
[(3, 10), (6, 12)],
cached_k,
cached_v,
sliding_window_sizes,
attn_sink_sizes,
)
# seq_len: [15+6, 20+13, 25+7, 38, 41, 43, 24+6]


@tvm.testing.requires_gpu
