Skip to content

Commit

Permalink
[Runtime] Stateless interface of PagedKVCache leaf node commit (#17057)
Browse files Browse the repository at this point in the history
This PR changes the interface of the function
`CommitAcceptedTokenTreeNodeToKVCache` introduced recently for
PagedKVCache to a stateless interface. Previously the interface
was a stateful one, which made strong assumptions about the caller
side. This commit removes those assumptions so that the interface
becomes less confusing.
  • Loading branch information
MasterJH5574 authored Jun 2, 2024
1 parent 4ab91d4 commit b87d1f9
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 71 deletions.
4 changes: 3 additions & 1 deletion src/runtime/relax_vm/kv_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,11 @@ class AttentionKVCacheObj : public KVStateObj {
* The commit will update the KV cache by compacting the KV data and discarding
* the KV data of rejected tokens.
* This is a mandatory step when the BeginForward is given with a token tree.
* \param seq_ids The ids of the sequences to commit.
* \param leaf_indices The leaf token tree node index of each sequence.
*/
virtual void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) = 0;
virtual void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids,
const IntTuple& leaf_indices) = 0;

/************** Attention **************/

Expand Down
177 changes: 110 additions & 67 deletions src/runtime/relax_vm/paged_kv_cache.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,18 @@ struct Sequence {
*/
int last_block_attn_sink_size = 0;

/*! \brief Whether the current appended tokens form a chain (not a tree). */
bool is_chain = true;
/*! \brief The token tree parent pointer array of the current appended tokens. */
std::vector<int32_t> token_tree_parent_ptr;
/*! \brief The depth of each node in the token tree. */
std::vector<int32_t> token_tree_node_depths;
/*!
* \brief A boolean denoting whether the accepted token tree indices of
* this sequence are committed
*/
bool accepted_indices_committed = true;

explicit Sequence(std::vector<Block>* global_block_pool, int32_t last_block_idx) {
++global_block_pool->at(last_block_idx).external_ref_cnt;
this->last_block_idx = last_block_idx;
Expand Down Expand Up @@ -879,10 +891,6 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
IntTuple cur_seq_ids_;
/*! \brief The append lengths of the sequences in the current round of forwarding. */
IntTuple cur_append_lengths_;
/*! \brief The token tree parent array of the sequences in the current round of forwarding. */
IntTuple cur_token_tree_parent_ptr_{nullptr};
/*! \brief The depth of each node in the token tree, for the sequences in the current batch. */
std::vector<std::vector<int32_t>> cur_token_tree_node_depths_;
/*! \brief Whether the current batch of sequences are token chains (not token trees). */
bool is_chain_;
/*! \brief Number of fork depth in the current round of forward. */
Expand Down Expand Up @@ -1187,6 +1195,9 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
<< "The forked position should be non-negative, or -1 for last position as default.";
CHECK_LE(fork_pos, parent_it->second.seq_length)
<< "The forked position should not exceed the total length of parent sequence.";
CHECK(parent_it->second.accepted_indices_committed)
<< "The parent sequence's token tree computed in the last round of forward has not been "
"committed with accepted nodes.";

int32_t child_block_idx = GetFreeBlock();
if (fork_pos == -1 || fork_pos == parent_it->second.seq_length) {
Expand Down Expand Up @@ -1434,25 +1445,13 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {

void BeginForward(const IntTuple& seq_ids, const IntTuple& append_lengths,
const Optional<IntTuple>& opt_token_tree_parent_ptr) final {
CHECK(!cur_token_tree_parent_ptr_.defined())
<< "The last round of forward which involves token tree has not been committed. Please "
"call \"CommitAcceptedTreeNodes\" to commit the accepted tokens.";

CHECK_EQ(seq_ids.size(), append_lengths.size())
<< "The seq_ids size (" << seq_ids.size() << ") and append_lengths size ("
<< append_lengths.size() << ") mismatch.";
cur_batch_size_ = seq_ids.size();
cur_seq_ids_ = seq_ids;
cur_append_lengths_ = append_lengths;

// - Check token tree validity and process the token tree.
is_chain_ = true;
tree_attn_mask_host_.clear();
tree_attn_mn_indptr_host_.clear();
if (opt_token_tree_parent_ptr.defined()) {
is_chain_ = ConstructTokenTreeMask(opt_token_tree_parent_ptr.value());
}

// - Collect sequence/block/page information for attention.
std::vector<Sequence*> sequences;
std::vector<int32_t> last_block_length_before_append;
Expand All @@ -1474,6 +1473,29 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}

// - Check token tree validity and process the token tree.
is_chain_ = true;
tree_attn_mask_host_.clear();
tree_attn_mn_indptr_host_.clear();
if (opt_token_tree_parent_ptr.defined()) {
is_chain_ = ConstructTokenTreeMask(sequences, opt_token_tree_parent_ptr.value());
} else {
// The input batch does not form trees. So each sequence in the batch
// is required to have all past accepted tokens committed.
for (int i = 0; i < cur_batch_size_; ++i) {
Sequence* sequence = sequences[i];
CHECK(sequence->accepted_indices_committed)
<< "The input batch does not form a tree, in which case the sequences in the input "
"batch are expected to have their accepted tokens token tree nodes committed. "
"Please invoke CommitAcceptedTokenTreeNodes for sequence "
<< seq_ids[i];
sequence->is_chain = true;
sequence->token_tree_parent_ptr.clear();
sequence->token_tree_node_depths.clear();
}
is_chain_ = true;
}

std::vector<std::vector<int32_t>> block_ids_on_depths = GetBlockIdsOnDepth(sequences);
num_depths_ = block_ids_on_depths.size();
ICHECK_LE(num_depths_, kPagedKVCacheMaxBlockDepth);
Expand Down Expand Up @@ -1559,7 +1581,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
for (int64_t pos = 0; pos < append_length; ++pos) {
q_rope_position_map_host_.push_back(
k_ragged_rope_pos_offset_host_[i] +
(is_chain_ ? pos : cur_token_tree_node_depths_[i][pos]));
(is_chain_ ? pos : sequences[i]->token_tree_node_depths[pos]));

int32_t pos_in_block = block.seq_length - append_length + pos;
if (last_block_length_before_append[i] + pos < block.sink_length) {
Expand Down Expand Up @@ -1649,19 +1671,26 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
}

void CommitAcceptedTokenTreeNodes(const IntTuple& leaf_indices) final {
CHECK_NE(cur_batch_size_, -1)
<< "Cannot commit accepted token tree nodes since BeginForward is not invoked.";
CHECK_EQ(leaf_indices.size(), cur_batch_size_)
<< "The number of input leaf indices does not equal to the current batch size.";
void CommitAcceptedTokenTreeNodes(const IntTuple& seq_ids, const IntTuple& leaf_indices) final {
CHECK_EQ(seq_ids.size(), leaf_indices.size())
<< "The given seq_ids and leaf_indices have different size.";
int num_seq_to_commit = seq_ids.size();

for (int i = 0; i < cur_batch_size_; ++i) {
CHECK_GE(leaf_indices[i], 0)
<< "Invalid tree index " << leaf_indices[i] << " which is negative";
CHECK_LT(leaf_indices[i], cur_append_lengths_[i])
std::vector<Sequence*> sequences;
sequences.reserve(num_seq_to_commit);
for (int i = 0; i < num_seq_to_commit; ++i) {
auto it = seq_map_.find(seq_ids[i]);
CHECK(it != seq_map_.end()) << "The sequence \"" << seq_ids[i]
<< "\" cannot be found in KV cache.";
sequences.push_back(&it->second);
CHECK(!it->second.accepted_indices_committed)
<< "The accepted nodes of sequence " << seq_ids[i] << " are already committed.";
CHECK_GE(leaf_indices[i], -1)
<< "Invalid tree index " << leaf_indices[i] << " which is less than -1";
CHECK_LT(leaf_indices[i], static_cast<int64_t>(it->second.token_tree_parent_ptr.size()))
<< "Invalid tree index " << leaf_indices[i]
<< " which is larger than or equals to the append length " << cur_append_lengths_[i]
<< " of the sequence";
<< " which is larger than or equals to the append length "
<< it->second.token_tree_parent_ptr.size() << " of the sequence";
}

if (!is_chain_) {
Expand All @@ -1670,16 +1699,21 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
commit_copy_dst_pos_in_page_table_host_.clear();
commit_copy_length_indptr_host_.push_back(0);

for (int i = 0; i < cur_batch_size_; ++i) {
for (int i = 0; i < num_seq_to_commit; ++i) {
if (leaf_indices[i] == -1) {
// No node is accepted. All nodes in the token tree need to be popped.
continue;
}

// Get the accepted node path on the token tree.
std::vector<int32_t> path_on_tree;
path_on_tree.reserve(cur_token_tree_node_depths_[i][leaf_indices[i]] + 1);
path_on_tree.reserve(sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1);
int node = leaf_indices[i];
while (node != -1) {
path_on_tree.push_back(node);
node = cur_token_tree_parent_ptr_[cur_append_lengths_indptr_host_[i] + node];
node = sequences[i]->token_tree_parent_ptr[node];
}
ICHECK_EQ(path_on_tree.size(), cur_token_tree_node_depths_[i][leaf_indices[i]] + 1);
ICHECK_EQ(path_on_tree.size(), sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1);
// Get the destination array (range [0, path_length - 1)) of KV cache copy.
std::vector<int32_t> copy_dst_pos_in_seq;
copy_dst_pos_in_seq.resize(path_on_tree.size());
Expand Down Expand Up @@ -1714,14 +1748,16 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
// Note: Function "PopN" only changes the page table structure and does not
// change the KV cache data. Therefore, we can directly use it, since
// we have already launched all copies.
for (int i = 0; i < cur_batch_size_; ++i) {
for (int i = 0; i < num_seq_to_commit; ++i) {
int64_t length_to_pop =
cur_append_lengths_[i] - cur_token_tree_node_depths_[i][leaf_indices[i]] - 1;
cur_append_lengths_[i] -
(leaf_indices[i] != -1 ? (sequences[i]->token_tree_node_depths[leaf_indices[i]] + 1) : 0);
PopN(cur_seq_ids_[i], length_to_pop);
// Reset the sequence states.
sequences[i]->accepted_indices_committed = true;
sequences[i]->token_tree_parent_ptr.clear();
sequences[i]->token_tree_node_depths.clear();
}

// Reset the token tree.
cur_token_tree_parent_ptr_ = IntTuple{nullptr};
}

NDArray GetQueryPositions() final {
Expand Down Expand Up @@ -1814,70 +1850,77 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
return block_idx;
}

bool ConstructTokenTreeMask(const IntTuple& token_tree_parent_ptr) {
bool ConstructTokenTreeMask(const std::vector<Sequence*>& sequences,
const IntTuple& token_tree_parent_ptr) {
// We check if the token tree deteriorates to a chain,
// because chain cases can have simplified attention work flow.
bool is_chain = true;
cur_token_tree_parent_ptr_ = token_tree_parent_ptr;
cur_token_tree_node_depths_.clear();
cur_token_tree_node_depths_.reserve(cur_batch_size_);

int64_t sum_append_length = 0;
int64_t sum_new_append_length = 0;
// - Construct the mn indptr array, which is the indptr of the mask size of each sequence.
tree_attn_mn_indptr_host_.push_back(0);
for (int64_t append_length : cur_append_lengths_) {
sum_append_length += append_length;
ICHECK_EQ(sequences.size(), cur_batch_size_);
ICHECK_EQ(cur_append_lengths_.size(), cur_batch_size_);
for (int i = 0; i < cur_batch_size_; ++i) {
int64_t append_length = cur_append_lengths_[i];
// Update the token tree parent pointers.
sequences[i]->token_tree_parent_ptr = {
token_tree_parent_ptr->data + sum_new_append_length,
token_tree_parent_ptr->data + sum_new_append_length + cur_append_lengths_[i]};
sum_new_append_length += cur_append_lengths_[i];

CHECK_LE(append_length, kTreeAttnMaxTreeSize)
<< "The tree size is " << append_length << " which exceeds the maximum tree size limit "
<< kTreeAttnMaxTreeSize;
tree_attn_mn_indptr_host_.push_back(tree_attn_mn_indptr_host_.back() +
static_cast<int32_t>(append_length * append_length));
append_length * append_length);
}
CHECK_EQ(token_tree_parent_ptr.size(), sum_append_length)
<< "Invalid token tree size. The sum of \"append_lengths\" is " << sum_append_length
CHECK_EQ(token_tree_parent_ptr.size(), sum_new_append_length)
<< "Invalid token tree size. The sum of \"append_lengths\" is " << sum_new_append_length
<< " while there are " << token_tree_parent_ptr.size()
<< " elements in \"token_tree_parent_ptr\".";

// - Construct the mask of each sequence.
int processed_pos = 0;
for (int i = 0; i < cur_batch_size_; ++i) {
int64_t append_length = cur_append_lengths_[i];
int64_t tree_size = sequences[i]->token_tree_parent_ptr.size();
std::vector<std::vector<int32_t>> mask;
std::vector<int32_t> depth;
mask.reserve(append_length);
depth.reserve(append_length);
for (int64_t n = 0; n < append_length; ++n) {
CHECK_LT(token_tree_parent_ptr[processed_pos], n)
mask.reserve(tree_size);
depth.reserve(tree_size);
sequences[i]->is_chain = true;
sequences[i]->accepted_indices_committed = false;
for (int64_t n = 0; n < tree_size; ++n) {
CHECK_LT(sequences[i]->token_tree_parent_ptr[n], n)
<< "Invalid token tree. The parent of node " << n << " in tree " << i << " is "
<< token_tree_parent_ptr[processed_pos] << ", which is not smaller than " << n;
CHECK_GE(token_tree_parent_ptr[processed_pos], -1)
<< sequences[i]->token_tree_parent_ptr[n] << ", which is not smaller than " << n;
CHECK_GE(sequences[i]->token_tree_parent_ptr[n], -1)
<< "Invalid token tree. The parent of node " << n << " in tree " << i << " is "
<< token_tree_parent_ptr[processed_pos];
if (token_tree_parent_ptr[processed_pos] != n - 1) {
<< sequences[i]->token_tree_parent_ptr[n];
if (sequences[i]->token_tree_parent_ptr[n] != n - 1) {
// The parent of the current node is not the last node.
// Therefore the tree is not a chain.
sequences[i]->is_chain = false;
is_chain = false;
}

std::vector<int32_t> single_pos_mask;
if (token_tree_parent_ptr[processed_pos] != -1) {
if (sequences[i]->token_tree_parent_ptr[n] != -1) {
// The current node has a parent in the token tree.
single_pos_mask = {mask[token_tree_parent_ptr[processed_pos]].begin(),
mask[token_tree_parent_ptr[processed_pos]].end()};
depth.push_back(depth[token_tree_parent_ptr[processed_pos]] + 1);
single_pos_mask = {mask[sequences[i]->token_tree_parent_ptr[n]].begin(),
mask[sequences[i]->token_tree_parent_ptr[n]].end()};
depth.push_back(depth[sequences[i]->token_tree_parent_ptr[n]] + 1);
} else {
// The current node is root in the token tree.
single_pos_mask.resize(append_length, /*value=*/0);
single_pos_mask.resize(tree_size, /*value=*/0);
depth.push_back(0);
}
single_pos_mask[n] = 1;
mask.push_back(single_pos_mask);
for (int32_t mask_val : single_pos_mask) {
tree_attn_mask_host_.push_back(mask_val);
}

++processed_pos;
}
cur_token_tree_node_depths_.push_back(std::move(depth));
sequences[i]->token_tree_node_depths = std::move(depth);
}

return is_chain;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -438,7 +438,10 @@ def apply_attention(
fend_forward(kv_cache)

if accepted_leaf_indices is not None:
fcommit_accepted_token_tree_nodes(kv_cache, ShapeTuple(accepted_leaf_indices))
seq_ids = [seq_id for seq_id, _ in batch]
fcommit_accepted_token_tree_nodes(
kv_cache, ShapeTuple(seq_ids), ShapeTuple(accepted_leaf_indices)
)
for i, (accepted_leaf_idx, (seq_id, append_length)) in enumerate(
zip(accepted_leaf_indices, batch)
):
Expand All @@ -449,7 +452,7 @@ def apply_attention(
node = token_tree_parent_ptr_list[i][node]
offset = cached_k[seq_id].shape[1] - append_length
length_to_pop = append_length - len(tree_path)
assert 0 <= length_to_pop < append_length
assert 0 <= length_to_pop <= append_length
for dst_pos, src_pos in enumerate(reversed(tree_path)):
if dst_pos == src_pos:
continue
Expand Down Expand Up @@ -773,7 +776,7 @@ def test_paged_attention_kv_cache_tree_attn(kv_cache_and_config):
[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8], # chain of length 10
[-1, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12], # chain of length 14
],
accepted_leaf_indices=[2, 6, 6, 4],
accepted_leaf_indices=[2, 6, -1, 4],
)
# Do 5 rounds of decode.
for _ in range(5):
Expand Down

0 comments on commit b87d1f9

Please sign in to comment.