
Commit

Recompute KV cache for Phi3 when switching from short to long factor (#1161)

Recompute KV cache for Phi3 when switching from short to long factor.

Verified that this PR fixes the issue for:
1. Phi3.5 mini
2. Phi3 mini 128K
3. Phi3 small
4. Phi3 medium
ajindal1 authored Jan 8, 2025
1 parent a715113 commit 41c2543
Showing 1 changed file with 14 additions and 0 deletions.
src/generators.cpp: 14 additions & 0 deletions
@@ -389,6 +389,20 @@ void Generator::GenerateNextToken() {
ThrowErrorIfSessionTerminated(state_->session_terminated_);
if (search_->GetSequenceLength() == 0 && !computed_logits_)
throw std::runtime_error("GenerateNextToken called with no prior state. Please call AppendTokens, SetLogits, or params.SetInputs before calling GenerateNextToken.");

// TODO: Extend the solution to work for batch size > 1, num beams > 1, multimodal models, and DML
// Phi3 models switch from the short factor to the long factor at token 4097 (original_max_position_embeddings + 1),
// which requires recomputing the position IDs and KV cache. This is achieved by rewinding to zero
// and re-appending the current sequence.
// Scenarios where this solution works: batch size = 1, num beams = 1, decoder-only model, EP is CPU or CUDA
// Scenarios where it doesn't work: batch size > 1, num beams > 1, multimodal model (like Phi3 vision), or the DML EP
if (search_->params_->BatchBeamSize() == 1) {
  if (((search_->GetSequenceLength() == 4097) && (model_->config_->model.type == "phi3" || model_->config_->model.type == "phimoe")) ||
      ((search_->GetSequenceLength() == 8197) && (model_->config_->model.type == "phi3small"))) {
    auto current_seq = cpu_span<int32_t>(GetSequence(0).CopyDeviceToCpu());
    RewindToLength(0);
    AppendTokens(current_seq);
  }
}

if (!computed_logits_) {
auto next_tokens = search_->GetNextTokens();
if (last_action_ == Action::rewound)
