diff --git a/src/runtime/relax_vm/kv_state.h b/src/runtime/relax_vm/kv_state.h index 8de560f12266..f4d6036b9638 100644 --- a/src/runtime/relax_vm/kv_state.h +++ b/src/runtime/relax_vm/kv_state.h @@ -151,9 +151,11 @@ class AttentionKVCacheObj : public KVStateObj { * The commit will update the KV cache, by compacting the KV data and discard * the KV data of rejected tokens. * This is a mandatory step when the BeginForward is given with a token tree. + * \param seq_ids The ids of the sequences to commit. * \param leaf_indices The leaf token tree node index of each sequence. */ - virtual void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) = 0; + virtual void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids, + const IntTuple& leaf_indices) = 0; /************** Attention **************/ diff --git a/src/runtime/relax_vm/paged_kv_cache.cc b/src/runtime/relax_vm/paged_kv_cache.cc index a5b970e81716..2fc5da78e979 100644 --- a/src/runtime/relax_vm/paged_kv_cache.cc +++ b/src/runtime/relax_vm/paged_kv_cache.cc @@ -151,6 +151,18 @@ struct Sequence { */ int last_block_attn_sink_size = 0; + /*! \brief Whether the current appended tokens form a chain (not a tree). */ + bool is_chain = true; + /*! \brief The token tree parent pointer array of the current appended tokens. */ + std::vector token_tree_parent_ptr; + /*! \brief The depth of each node in the token tree. */ + std::vector token_tree_node_depths; + /*! + * \brief A boolean denoting whether the accepted token tree indices of + * this sequence are committed + */ + bool accepted_indices_committed = true; + explicit Sequence(std::vector* global_block_pool, int32_t last_block_idx) { ++global_block_pool->at(last_block_idx).external_ref_cnt; this->last_block_idx = last_block_idx; @@ -879,10 +891,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { IntTuple cur_seq_ids_; /*! \brief The append lengths of the sequences in the current round of forwarding. */ IntTuple cur_append_lengths_; - /*! \brief The token tree parent array of the sequences in the current round of forwarding. */ - IntTuple cur_token_tree_parent_ptr_{nullptr}; - /*! \brief The depth of each node in the token tree, for the sequences in the current batch. */ - std::vector> cur_token_tree_node_depths_; /*! \brief Whether the current batch of sequences are token chains (not token trees). */ bool is_chain_; /*! \brief Number of fork depth in the current round of forward. */ @@ -1187,6 +1195,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { << "The forked position should be non-negative, or -1 for last position as default."; CHECK_LE(fork_pos, parent_it->second.seq_length) << "The forked position should not exceed the total length of parent sequence."; + CHECK(parent_it->second.accepted_indices_committed) + << "The parent sequence's token tree computed in the last round of forward has not been " + "committed with accepted nodes."; int32_t child_block_idx = GetFreeBlock(); if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) { @@ -1434,10 +1445,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths, const Optional& opt_token_tree_parent_ptr) final { - CHECK(!cur_token_tree_parent_ptr_.defined()) - << "The last round of forward which involves token tree has not been committed. Please " - "call \"CommitAcceptedTreeNodes\" to commit the accepted tokens."; - CHECK_EQ(seq_ids.size(), append_lengths.size()) << "The seq_ids size (" << seq_ids.size() << ") and append_lengths size (" << append_lengths.size() << ") mismatch."; @@ -1445,14 +1452,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { cur_seq_ids_ = seq_ids; cur_append_lengths_ = append_lengths; - // - Check token tree validity and process the token tree. - is_chain_ = true; - tree_attn_mask_host_.clear(); - tree_attn_mn_indptr_host_.clear(); - if (opt_token_tree_parent_ptr.defined()) { - is_chain_ = ConstructTokenTreeMask(opt_token_tree_parent_ptr.value()); - } - // - Collect sequence/block/page information for attention. std::vector sequences; std::vector last_block_length_before_append; @@ -1474,6 +1473,29 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } + // - Check token tree validity and process the token tree. + is_chain_ = true; + tree_attn_mask_host_.clear(); + tree_attn_mn_indptr_host_.clear(); + if (opt_token_tree_parent_ptr.defined()) { + is_chain_ = ConstructTokenTreeMask(sequences, opt_token_tree_parent_ptr.value()); + } else { + // The input batch does not form trees. So each sequence in the batch + // is required to have all past accepted tokens committed. + for (int i = 0; i < cur_batch_size_; ++i) { + Sequence* sequence = sequences[i]; + CHECK(sequence->accepted_indices_committed) + << "The input batch does not form a tree, in which case the sequences in the input " + "batch are expected to have their accepted tokens token tree nodes committed. " + "Please invoke CommitAcceptedTokenTreeNodes for sequence " + << seq_ids[i]; + sequence->is_chain = true; + sequence->token_tree_parent_ptr.clear(); + sequence->token_tree_node_depths.clear(); + } + is_chain_ = true; + } + std::vector> block_ids_on_depths = GetBlockIdsOnDepth(sequences); num_depths_ = block_ids_on_depths.size(); ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth); @@ -1559,7 +1581,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int64_t pos = 0; pos < append_length; ++pos) { q_rope_position_map_host_.push_back( k_ragged_rope_pos_offset_host_[i] + - (is_chain_ ? pos : cur_token_tree_node_depths_[i][pos])); + (is_chain_ ? pos : sequences[i]->token_tree_node_depths[pos])); int32_t pos_in_block = block.seq_length - append_length + pos; if (last_block_length_before_append[i] + pos < block.sink_length) { @@ -1649,19 +1671,26 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { } } - void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) final { - CHECK_NE(cur_batch_size_, -1) - << "Cannot commit accepted token tree nodes since BeginForward is not invoked."; - CHECK_EQ(leaf_indices.size(), cur_batch_size_) - << "The number of input leaf indices does not equal to the current batch size."; + void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids, const IntTuple& leaf_indices) final { + CHECK_EQ(seq_ids.size(), leaf_indices.size()) + << "The given seq_ids and leaf_indices have different size."; + int num_seq_to_commit = seq_ids.size(); - for (int i = 0; i < cur_batch_size_; ++i) { - CHECK_GE(leaf_indices[i], 0) - << "Invalid tree index " << leaf_indices[i] << " which is negative"; - CHECK_LT(leaf_indices[i], cur_append_lengths_[i]) + std::vector sequences; + sequences.reserve(num_seq_to_commit); + for (int i = 0; i < num_seq_to_commit; ++i) { + auto it = seq_map_.find(seq_ids[i]); + CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i] + << "\" cannot be found in KV cache."; + sequences.push_back(&it->second); + CHECK(!it->second.accepted_indices_committed) + << "The accepted nodes of sequence " << seq_ids[i] << " are already committed."; + CHECK_GE(leaf_indices[i], -1) + << "Invalid tree index " << leaf_indices[i] << " which is less than -1"; + CHECK_LT(leaf_indices[i], static_cast(it->second.token_tree_parent_ptr.size())) << "Invalid tree index " << leaf_indices[i] - << " which is larger than or equals to the append length " << cur_append_lengths_[i] - << " of the sequence"; + << " which is larger than or equals to the append length " + << it->second.token_tree_parent_ptr.size() << " of the sequence"; } if (!is_chain_) { @@ -1670,16 +1699,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { commit_copy_dst_pos_in_page_table_host_.clear(); commit_copy_length_indptr_host_.push_back(0); - for (int i = 0; i < cur_batch_size_; ++i) { + for (int i = 0; i < num_seq_to_commit; ++i) { + if (leaf_indices[i] == -1) { + // No node is accepted. All nodes in the token tree need to be popped. + continue; + } + // Get the accepted node path on the token tree. std::vector path_on_tree; - path_on_tree.reserve(cur_token_tree_node_depths_[i][leaf_indices[i]] + 1); + path_on_tree.reserve(sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1); int node = leaf_indices[i]; while (node != -1) { path_on_tree.push_back(node); - node = cur_token_tree_parent_ptr_[cur_append_lengths_indptr_host_[i] + node]; + node = sequences[i]->token_tree_parent_ptr[node]; } - ICHECK_EQ(path_on_tree.size(), cur_token_tree_node_depths_[i][leaf_indices[i]] + 1); + ICHECK_EQ(path_on_tree.size(), sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1); // Get the destination array (range [0, path_length - 1)) of KV cache copy. std::vector copy_dst_pos_in_seq; copy_dst_pos_in_seq.resize(path_on_tree.size()); @@ -1714,14 +1748,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { // Note: Function "PopN" only changes the page table structure and does not // change the KV cache data. Therefore, we can directly use it, since // we have already launched all copies. - for (int i = 0; i < cur_batch_size_; ++i) { + for (int i = 0; i < num_seq_to_commit; ++i) { int64_t length_to_pop = - cur_append_lengths_[i] - cur_token_tree_node_depths_[i][leaf_indices[i]] - 1; + cur_append_lengths_[i] - + (leaf_indices[i] != -1 ? (sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1) : 0); PopN(cur_seq_ids_[i], length_to_pop); + // Reset the sequence states. + sequences[i]->accepted_indices_committed = true; + sequences[i]->token_tree_parent_ptr.clear(); + sequences[i]->token_tree_node_depths.clear(); } - - // Reset the token tree. - cur_token_tree_parent_ptr_ = IntTuple{nullptr}; } NDArray GetQueryPositions() final { @@ -1814,57 +1850,67 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { return block_idx; } - bool ConstructTokenTreeMask(const IntTuple& token_tree_parent_ptr) { + bool ConstructTokenTreeMask(const std::vector& sequences, + const IntTuple& token_tree_parent_ptr) { // We check if the token tree deteriorates to a chain, // because chain cases can have simplified attention work flow. bool is_chain = true; - cur_token_tree_parent_ptr_ = token_tree_parent_ptr; - cur_token_tree_node_depths_.clear(); - cur_token_tree_node_depths_.reserve(cur_batch_size_); - - int64_t sum_append_length = 0; + int64_t sum_new_append_length = 0; // - Construct the mn indptr array, which is the indptr of the mask size of each sequence. tree_attn_mn_indptr_host_.push_back(0); - for (int64_t append_length : cur_append_lengths_) { - sum_append_length += append_length; + ICHECK_EQ(sequences.size(), cur_batch_size_); + ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_); + for (int i = 0; i < cur_batch_size_; ++i) { + int64_t append_length = cur_append_lengths_[i]; + // Update the token tree parent pointers. + sequences[i]->token_tree_parent_ptr = { + token_tree_parent_ptr->data + sum_new_append_length, + token_tree_parent_ptr->data + sum_new_append_length + cur_append_lengths_[i]}; + sum_new_append_length += cur_append_lengths_[i]; + + CHECK_LE(append_length, kTreeAttnMaxTreeSize) + << "The tree size is " << append_length << " which exceeds the maximum tree size limit " + << kTreeAttnMaxTreeSize; tree_attn_mn_indptr_host_.push_back(tree_attn_mn_indptr_host_.back() + - static_cast(append_length * append_length)); + append_length * append_length); } - CHECK_EQ(token_tree_parent_ptr.size(), sum_append_length) - << "Invalid token tree size. The sum of \"append_lengths\" is " << sum_append_length + CHECK_EQ(token_tree_parent_ptr.size(), sum_new_append_length) + << "Invalid token tree size. The sum of \"append_lengths\" is " << sum_new_append_length << " while there are " << token_tree_parent_ptr.size() << " elements in \"token_tree_parent_ptr\"."; // - Construct the mask of each sequence. - int processed_pos = 0; for (int i = 0; i < cur_batch_size_; ++i) { - int64_t append_length = cur_append_lengths_[i]; + int64_t tree_size = sequences[i]->token_tree_parent_ptr.size(); std::vector> mask; std::vector depth; - mask.reserve(append_length); - depth.reserve(append_length); - for (int64_t n = 0; n < append_length; ++n) { - CHECK_LT(token_tree_parent_ptr[processed_pos], n) + mask.reserve(tree_size); + depth.reserve(tree_size); + sequences[i]->is_chain = true; + sequences[i]->accepted_indices_committed = false; + for (int64_t n = 0; n < tree_size; ++n) { + CHECK_LT(sequences[i]->token_tree_parent_ptr[n], n) << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " - << token_tree_parent_ptr[processed_pos] << ", which is not smaller than " << n; - CHECK_GE(token_tree_parent_ptr[processed_pos], -1) + << sequences[i]->token_tree_parent_ptr[n] << ", which is not smaller than " << n; + CHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1) << "Invalid token tree. The parent of node " << n << " in tree " << i << " is " - << token_tree_parent_ptr[processed_pos]; - if (token_tree_parent_ptr[processed_pos] != n - 1) { + << sequences[i]->token_tree_parent_ptr[n]; + if (sequences[i]->token_tree_parent_ptr[n] != n - 1) { // The parent of the current node is not the last node. // Therefore the tree is not a chain. + sequences[i]->is_chain = false; is_chain = false; } std::vector single_pos_mask; - if (token_tree_parent_ptr[processed_pos] != -1) { + if (sequences[i]->token_tree_parent_ptr[n] != -1) { // The current node has a parent in the token tree. - single_pos_mask = {mask[token_tree_parent_ptr[processed_pos]].begin(), - mask[token_tree_parent_ptr[processed_pos]].end()}; - depth.push_back(depth[token_tree_parent_ptr[processed_pos]] + 1); + single_pos_mask = {mask[sequences[i]->token_tree_parent_ptr[n]].begin(), + mask[sequences[i]->token_tree_parent_ptr[n]].end()}; + depth.push_back(depth[sequences[i]->token_tree_parent_ptr[n]] + 1); } else { // The current node is root in the token tree. - single_pos_mask.resize(append_length, /*value=*/0); + single_pos_mask.resize(tree_size, /*value=*/0); depth.push_back(0); } single_pos_mask[n] = 1; @@ -1872,12 +1918,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj { for (int32_t mask_val : single_pos_mask) { tree_attn_mask_host_.push_back(mask_val); } - - ++processed_pos; } - cur_token_tree_node_depths_.push_back(std::move(depth)); + sequences[i]->token_tree_node_depths = std::move(depth); } - return is_chain; } diff --git a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py index 0a69d184e5a9..c5c88211ba18 100644 --- a/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py +++ b/tests/python/relax/test_runtime_builtin_paged_attention_kv_cache_tir.py @@ -438,7 +438,10 @@ def apply_attention( fend_forward(kv_cache) if accepted_leaf_indices is not None: - fcommit_accepted_token_tree_nodes(kv_cache, ShapeTuple(accepted_leaf_indices)) + seq_ids = [seq_id for seq_id, _ in batch] + fcommit_accepted_token_tree_nodes( + kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices) + ) for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate( zip(accepted_leaf_indices, batch) ): @@ -449,7 +452,7 @@ def apply_attention( node = token_tree_parent_ptr_list[i][node] offset = cached_k[seq_id].shape[1] - append_length length_to_pop = append_length - len(tree_path) - assert 0 <= length_to_pop < append_length + assert 0 <= length_to_pop <= append_length for dst_pos, src_pos in enumerate(reversed(tree_path)): if dst_pos == src_pos: continue @@ -773,7 +776,7 @@ def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config): [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], # chain of length 10 [-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], # chain of length 14 ], - accepted_leaf_indices=[2, 6, 6, 4], + accepted_leaf_indices=[2, 6, -1, 4], ) # Do 5 rounds of decode. for _ in range(5):