diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc
index 0162124cab6b..ec1cc3593a53 100644
--- a/src/runtime/relax_vm/paged_kv_cache.cc
+++ b/src/runtime/relax_vm/paged_kv_cache.cc
@@ -1184,9 +1184,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         << "The parent sequence \"" << parent_seq_id << "\" cannot be found in KV cache.";
     CHECK(seq_map_.find(child_seq_id) == seq_map_.end())
         << "The child sequence \"" << child_seq_id << "\" is already in the KV cache.";
-    CHECK_EQ(parent_it->second.sliding_window_size, -1)
-        << "The parent sequence \"" << parent_seq_id
-        << "\" is enabled with sliding window and thus cannot be forked.";
     CHECK_GE(fork_pos, -1)
         << "The forked position should be non-negative, or -1 for last position as default.";
     CHECK_LE(fork_pos, parent_it->second.seq_length)
@@ -1199,6 +1196,18 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
       fork_pos = parent_it->second.seq_length;
     }
 
+    if (parent_it->second.sliding_window_size != -1) {
+      // If the parent sequence has sliding window enabled, the fork position must lie within
+      // the attention sink region, since positions beyond it may already have been slid out.
+      const Sequence& seq = parent_it->second;
+      int32_t sink_size = seq.seq_length - global_block_pool_[seq.last_block_idx].seq_length +
+                          seq.last_block_attn_sink_size;
+      CHECK_LE(fork_pos, sink_size)
+          << "The parent sequence \"" << parent_seq_id
+          << "\" has sliding window enabled and thus can only be forked within sink size = "
+          << sink_size << ". But the forked position = " << fork_pos << ".";
+    }
+
     if (fork_pos == parent_it->second.seq_length && fork_pos % page_size_ == 0 &&
         global_block_pool_[parent_it->second.last_block_idx].seq_length > 0) {
       // To enable the parent sequence to continue decode after the fork,
@@ -1258,6 +1267,14 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
         // Update in-block sequence length per blocks
         global_block_pool_[parent_block_idx].seq_length = moved_offset;
         global_block_pool_[forked_block_idx].seq_length -= moved_offset;
+
+        // Update the sliding window sink size if sliding window is enabled and the forked
+        // block is the last block.
+        if (parent_it->second.sliding_window_size != -1 &&
+            forked_block_idx == parent_it->second.last_block_idx) {
+          CHECK_LE(moved_offset, parent_it->second.last_block_attn_sink_size);
+          parent_it->second.last_block_attn_sink_size -= moved_offset;
+        }
       }
       global_block_pool_[child_block_idx].start_pos = fork_pos - in_page_offset;
       global_block_pool_[child_block_idx].seq_length = in_page_offset;
diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
index 87256720bdec..34680160c8de 100644
--- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
+++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py
@@ -468,8 +468,11 @@ def apply_attention(
 
     for seq_id, _ in batch:
         if sliding_window_sizes is not None and len(sliding_window_sizes) > seq_id:
+            assert len(sliding_window_sizes) > seq_id and len(attn_sink_sizes) > seq_id
             sliding_window_size = sliding_window_sizes[seq_id]
             attn_sink_size = attn_sink_sizes[seq_id]
+            if sliding_window_size == 0:
+                continue
             if cached_k[seq_id].shape[1] > sliding_window_size:
                 # Apply sliding window and sink to cached kv.
                 length_to_slide = cached_k[seq_id].shape[1] - sliding_window_size
@@ -746,34 +749,74 @@ def test_paged_attention_kv_cache_sliding_window(kv_cache_and_config):
         attn_sink_sizes,
     )
 
-    # Sliding window with fork
-    sliding_window_sizes += [0, 18]
-    attn_sink_sizes += [0, 12]
-    apply_attention(kv_cache, rope_mode, [(5, 10)], cached_k, cached_v)
-    ffork_sequence(kv_cache, 5, 6, -1)
-    cached_k[6] = cached_k[5]
-    cached_v[6] = cached_v[5]
+
+@tvm.testing.requires_gpu
+@tvm.testing.requires_cuda
+def test_paged_attention_kv_cache_sliding_window_fork(kv_cache_and_config):
+    kv_cache, rope_mode, support_sliding_window = kv_cache_and_config
+    if not support_sliding_window or rope_mode == RopeMode.NORMAL:
+        return
+    fclear(kv_cache)
+
+    cached_k = {}
+    cached_v = {}
+    sliding_window_sizes = [30, 35, 40]
+    attn_sink_sizes = [15, 20, 25]
+    for seq_id, (sliding_window_size, attn_sink_size) in enumerate(
+        zip(sliding_window_sizes, attn_sink_sizes)
+    ):
+        fadd_sequence(kv_cache, seq_id)
+        fenable_sliding_window_for_seq(kv_cache, seq_id, sliding_window_size, attn_sink_size)
+        cached_k[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+        cached_v[seq_id] = np.zeros((num_layers, 0, num_kv_heads, head_dim), dtype)
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [(0, 12), (1, 18), (2, 28)],
+        cached_k,
+        cached_v,
+        sliding_window_sizes,
+        attn_sink_sizes,
+    )
+    # seq_len: [12, 18, 25+3]
+    sliding_window_sizes += [0, 0, 0]
+    attn_sink_sizes += [0, 0, 0]
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [((3, 0, 10), 8), ((4, 1, -1), 20), ((5, 2, 18), 18)],
+        cached_k,
+        cached_v,
+        sliding_window_sizes,
+        attn_sink_sizes,
+    )
+    # seq_len: [12, 18, 25+3, 18, 38, 36]
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [(0, 9), (1, 15), (2, 4), (3, 10), (4, 3), (5, 7)],
+        cached_k,
+        cached_v,
+        sliding_window_sizes,
+        attn_sink_sizes,
+    )
+    # seq_len: [15+6, 20+13, 25+7, 28, 41, 43]
+    sliding_window_sizes += [25]
+    attn_sink_sizes += [24]
+    ffork_sequence(kv_cache, 3, 6, 18)
     fenable_sliding_window_for_seq(kv_cache, 6, sliding_window_sizes[-1], attn_sink_sizes[-1])
-    for _ in range(2):
-        apply_attention(
-            kv_cache,
-            rope_mode,
-            [(6, 10)],
-            cached_k,
-            cached_v,
-            sliding_window_sizes,
-            attn_sink_sizes,
-        )
-    for _ in range(16):
-        apply_attention(
-            kv_cache,
-            rope_mode,
-            [(6, 1)],
-            cached_k,
-            cached_v,
-            sliding_window_sizes,
-            attn_sink_sizes,
-        )
+    cached_k[6] = cached_k[3][::, :18]
+    cached_v[6] = cached_v[3][::, :18]
+    apply_attention(
+        kv_cache,
+        rope_mode,
+        [(3, 10), (6, 12)],
+        cached_k,
+        cached_v,
+        sliding_window_sizes,
+        attn_sink_sizes,
+    )
+    # seq_len: [15+6, 20+13, 25+7, 38, 41, 43, 24+6]
 
 
 @tvm.testing.requires_gpu
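The runtime change above boils down to one bound: a sliding-window sequence may only be forked inside its attention sink, i.e. at a position no greater than seq_length - last_block_seq_length + last_block_attn_sink_size. Below is a minimal, self-contained Python sketch of that check; the helper name and its scalar parameters are hypothetical stand-ins for the Sequence and block fields used in the C++ diff, not TVM's API.

```python
def check_fork_within_sink(
    seq_length: int,
    last_block_seq_length: int,
    last_block_attn_sink_size: int,
    fork_pos: int,
) -> None:
    """Mirror of the CHECK_LE added in ForkSequence (illustrative only)."""
    # Everything before the last block, plus the sink portion of the last block,
    # is still resident; positions past that may already have been slid out.
    sink_size = seq_length - last_block_seq_length + last_block_attn_sink_size
    if fork_pos > sink_size:
        raise ValueError(
            f"Sliding-window sequence can only be forked within sink size = {sink_size}, "
            f"but the forked position = {fork_pos}."
        )


# Hypothetical numbers: a 43-token sequence whose last block holds 30 tokens,
# 24 of them attention sink -> sink_size = 37, so forking at position 18 is allowed.
check_fork_within_sink(
    seq_length=43, last_block_seq_length=30, last_block_attn_sink_size=24, fork_pos=18
)
```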
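On the test side, apply_attention keeps a NumPy mirror of each sequence's K/V and, once that mirror grows past the sliding window, drops the span between the attention sink and the most recent window before comparing against the cache; the added sliding_window_size == 0 guard simply skips sequences that never enabled a window. The sketch below models that reference eviction on a plain list of token positions, assuming the usual sink-plus-recent-window semantics; it is an approximation for illustration, not the helper's exact array code.

```python
def slide_reference(tokens, sliding_window_size, attn_sink_size):
    # sliding_window_size == 0 marks a sequence without a sliding window,
    # which is exactly what the added `continue` in the diff skips.
    if sliding_window_size == 0 or len(tokens) <= sliding_window_size:
        return tokens
    length_to_slide = len(tokens) - sliding_window_size
    # Keep the attention-sink prefix, evict `length_to_slide` positions after it,
    # and keep the most recent positions that still fit in the window.
    return tokens[:attn_sink_size] + tokens[attn_sink_size + length_to_slide:]


print(slide_reference(list(range(40)), sliding_window_size=30, attn_sink_size=15))
# -> positions 0..14 (sink) followed by 25..39 (recent tail), 30 tokens in total
```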