diff --git a/src/generators.cpp b/src/generators.cpp index 9415cb1db..cb50f268f 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -389,6 +389,20 @@ void Generator::GenerateNextToken() { ThrowErrorIfSessionTerminated(state_->session_terminated_); if (search_->GetSequenceLength() == 0 && !computed_logits_) throw std::runtime_error("GenerateNextToken called with no prior state. Please call AppendTokens, SetLogits, or params.SetInputs before calling GenerateNextToken."); + + // TODO: Extend the solution to make it work for batch size > 1, num beams > 1, multimodal and DML + // Phi3 model switches from short factor to long factor at 4097 (original_max_position_embeddings+1) token, needs Recomputation of Position IDs and KV Cache + // at this stage which is achieved by rewinding to zero and appending the current sequence + // Scenarios where this solution works: Batch size = 1, Num beams = 1, decoder model, EP is either CPU or CUDA + // Scenarios where it doesn't work: Batch size > 1 OR Num beams > 1 OR Multimodal model (like phi3 vision) OR EP is DML + if (search_->params_->BatchBeamSize() == 1) { + if (((search_->GetSequenceLength() == 4097) && (model_->config_->model.type == "phi3" || model_->config_->model.type == "phimoe")) || ((search_->GetSequenceLength() == 8197) && (model_->config_->model.type == "phi3small"))) { + auto current_seq = cpu_span(GetSequence(0).CopyDeviceToCpu()); + RewindToLength(0); + AppendTokens(current_seq); + } + } + if (!computed_logits_) { auto next_tokens = search_->GetNextTokens(); if (last_action_ == Action::rewound)