From 18a2a250f8c7f16f5f5be6753861ba5db8fb89fa Mon Sep 17 00:00:00 2001 From: Yaxing Cai Date: Mon, 20 May 2024 08:13:50 -0700 Subject: [PATCH] [KVCache] Support KVCache decode from forked sequence and pop more tokens (#16995) --- src/runtime/relax_vm/paged_kv_cache.cc | 65 +++++++++++++++++++++----- 1 file changed, 53 insertions(+), 12 deletions(-) diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index b07ae3d76d23..a5d2d9f41554 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -925,10 +925,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) { // Fork at last by appending a new block directly int32_t parent_block_idx = parent_it->second.last_block_idx; + if (!global_block_pool_[parent_block_idx].seq_length) { + // If parent ends with empty block, fork from parent's parent block + parent_block_idx = global_block_pool_[parent_block_idx].parent_idx; + } ++global_block_pool_[parent_block_idx].external_ref_cnt; // Update child block start position and parent index global_block_pool_[child_block_idx].start_pos = parent_it->second.seq_length; global_block_pool_[child_block_idx].parent_idx = parent_block_idx; + if (global_block_pool_[parent_block_idx].seq_length) { + // If parent is not empty, append a new block + int32_t new_parent_block_idx = GetFreeBlock(); + global_block_pool_[new_parent_block_idx].start_pos = parent_it->second.seq_length; + global_block_pool_[new_parent_block_idx].parent_idx = parent_block_idx; + parent_it->second.last_block_idx = new_parent_block_idx; + } } else { // Locate the block to fork from and calculate in-block offset std::vector trace = parent_it->second.GetBlockTrace(global_block_pool_); @@ -1038,21 +1049,51 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { auto it = seq_map_.find(seq_id); CHECK(it != seq_map_.end()) << "The sequence \"" << seq_id << "\" cannot be found in KV cache."; - Block& block = global_block_pool_[it->second.last_block_idx]; CHECK_GE(n, 0) << "The length of popping " << n << " cannot be negative."; - CHECK_LE(n, block.seq_length) << "The sequence only has length " << block.seq_length - << " in the last block, while the length of pop is " << n - << " which exceeds the last-block sequence length."; + CHECK_LE(n, it->second.seq_length) + << "The sequence only has length " << it->second.seq_length + << ", while the length of pop is " << n << " which exceeds the whole sequence length."; + int32_t block_idx = it->second.last_block_idx; + while (block_idx != -1 && global_block_pool_[block_idx].external_ref_cnt == 0) { + if (n > global_block_pool_[block_idx].seq_length) { + n -= global_block_pool_[block_idx].seq_length; + it->second.seq_length -= global_block_pool_[block_idx].seq_length; + for (int32_t page_id : global_block_pool_[block_idx].page_ids) { + free_page_ids_.push_back(page_id); + } + free_block_idx_.push_back(block_idx); + block_idx = global_block_pool_[block_idx].parent_idx; + it->second.last_block_idx = block_idx; + continue; + } + if (n <= global_block_pool_[block_idx].seq_length) { + int64_t cur_npage = global_block_pool_[block_idx].page_ids.size(); + int64_t tgt_npage = + (global_block_pool_[block_idx].seq_length - n + page_size_ - 1) / page_size_; + while (cur_npage > tgt_npage) { + free_page_ids_.push_back(global_block_pool_[block_idx].page_ids.back()); + global_block_pool_[block_idx].page_ids.pop_back(); + --cur_npage; + } + it->second.seq_length -= n; + global_block_pool_[block_idx].seq_length -= n; + n = 0; + break; + } + } - int64_t cur_npage = block.page_ids.size(); - int64_t tgt_npage = (block.seq_length - n + page_size_ - 1) / page_size_; - while (cur_npage > tgt_npage) { - free_page_ids_.push_back(block.page_ids.back()); - block.page_ids.pop_back(); - --cur_npage; + if (n) { + int32_t temp_seq_id = -1 - seq_id; + CHECK(seq_map_.find(temp_seq_id) == seq_map_.end()); + ForkSequence(seq_id, temp_seq_id, it->second.seq_length - n); + CHECK(seq_map_.find(temp_seq_id) != seq_map_.end()); + RemoveSequence(seq_id); + CHECK(seq_map_.find(seq_id) == seq_map_.end()); + auto it = seq_map_.find(temp_seq_id); + seq_map_.insert({seq_id, Sequence(global_block_pool_, it->second.last_block_idx)}); + seq_map_.erase(temp_seq_id); } - it->second.seq_length -= n; - block.seq_length -= n; + dirty_aux_data_device_ = true; }