From 6833733d049b98a878b877ee050f3f4365f8b07a Mon Sep 17 00:00:00 2001
From: Abhishek Jindal
Date: Fri, 20 Dec 2024 11:18:13 -0800
Subject: [PATCH 1/7] recompute KV cache for Phi3 when switching from short to long factor

---
 src/generators.cpp | 13 +++++++++++++
 src/generators.h   |  1 +
 2 files changed, 14 insertions(+)

diff --git a/src/generators.cpp b/src/generators.cpp
index eff98f5ca..1def62dbd 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -352,6 +352,19 @@ void Generator::GenerateNextToken() {
   ThrowErrorIfSessionTerminated(state_->session_terminated_);
   if (search_->GetSequenceLength() == 0 && !computed_logits_)
     throw std::runtime_error("GenerateNextToken called with no prior state. Please call AppendTokens, SetLogits, or params.SetInputs before calling GenerateNextToken.");
+
+  // TODO: Extend the solution to make it work for batch size > 1 and num beams > 1
+  // Phi3 model switches from short factor to long factor at 4097 (original_max_position_embeddings+1) token, needs Recomputation of Position IDs and KV Cache
+  // at this stage which is achieved by rewinding to zero and appending the current sequence
+  if (model_->config_->model.type == "phi3" && search_->params_->search.batch_size == 1 && params.search.num_beams == 1) {
+    if (search_->GetSequenceLength() == 4097 && first_switch) {
+      first_switch = false;
+      auto current_seq = cpu_span<int32_t>(GetSequence(0).CpuSpan());
+      RewindToLength(0);
+      AppendTokens(current_seq);
+    }
+  }
+
   if (!computed_logits_) {
     auto next_tokens = search_->GetNextTokens();
     if (last_action_ == Action::rewound)
diff --git a/src/generators.h b/src/generators.h
index 0b6dc7cfc..38c695987 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -125,6 +125,7 @@ struct Generator : LeakChecked<Generator> {
   std::unique_ptr<State> state_;
   std::unique_ptr<Search> search_;
   bool computed_logits_{};  // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio
+  bool first_switch{true};
 
 private:
   DeviceSpan<int32_t> AllocateInputIdsOnDevice(const cpu_span<int32_t> input_ids);

From d9622deb1dce6fe3c3d838108a0c03a19cc544f2 Mon Sep 17 00:00:00 2001
From: Abhishek Jindal
Date: Fri, 20 Dec 2024 11:37:31 -0800
Subject: [PATCH 2/7] fix typo

---
 src/generators.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/generators.cpp b/src/generators.cpp
index 1def62dbd..1005b7535 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -356,7 +356,7 @@ void Generator::GenerateNextToken() {
   // TODO: Extend the solution to make it work for batch size > 1 and num beams > 1
   // Phi3 model switches from short factor to long factor at 4097 (original_max_position_embeddings+1) token, needs Recomputation of Position IDs and KV Cache
   // at this stage which is achieved by rewinding to zero and appending the current sequence
-  if (model_->config_->model.type == "phi3" && search_->params_->search.batch_size == 1 && params.search.num_beams == 1) {
+  if (model_->config_->model.type == "phi3" && search_->params_->search.batch_size == 1 && search_->params_->search.num_beams == 1) {
     if (search_->GetSequenceLength() == 4097 && first_switch) {
       first_switch = false;
       auto current_seq = cpu_span<int32_t>(GetSequence(0).CpuSpan());
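
A note on why the rewind in patch 1 is needed at all: Phi-3's LongRoPE-style scaling selects one factor set for the whole sequence based on its current length, and that factor rescales every position starting at 0. Growing past original_max_position_embeddings therefore retroactively changes the rotary embedding of tokens whose keys and values were already cached under the short factor. A minimal sketch of that selection follows; the names are illustrative and not taken from this repository.

// rope_factor_sketch.cpp - illustrative only; names are not from this repository.
#include <cstddef>
#include <vector>

// ONE factor set is chosen for the entire sequence from its current length and
// applied from position 0 onward. Crossing original_max_position_embeddings
// thus invalidates keys/values cached under the short factor - hence the
// rewind-to-zero-and-reappend introduced in patch 1.
const std::vector<float>& SelectRopeFactors(std::size_t sequence_length,
                                            std::size_t original_max_position_embeddings,
                                            const std::vector<float>& short_factor,
                                            const std::vector<float>& long_factor) {
  return sequence_length > original_max_position_embeddings ? long_factor : short_factor;
}
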
From 24f0e11f44a8abf412ba97f04278286df91d8d82 Mon Sep 17 00:00:00 2001
From: Abhishek Jindal
Date: Fri, 20 Dec 2024 15:04:01 -0800
Subject: [PATCH 3/7] remove first switch

---
 src/generators.cpp | 3 +--
 src/generators.h   | 1 -
 2 files changed, 1 insertion(+), 3 deletions(-)

diff --git a/src/generators.cpp b/src/generators.cpp
index 1005b7535..522e6dab7 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -357,8 +357,7 @@ void Generator::GenerateNextToken() {
   // Phi3 model switches from short factor to long factor at 4097 (original_max_position_embeddings+1) token, needs Recomputation of Position IDs and KV Cache
   // at this stage which is achieved by rewinding to zero and appending the current sequence
   if (model_->config_->model.type == "phi3" && search_->params_->search.batch_size == 1 && search_->params_->search.num_beams == 1) {
-    if (search_->GetSequenceLength() == 4097 && first_switch) {
-      first_switch = false;
+    if (search_->GetSequenceLength() == 4097) {
       auto current_seq = cpu_span<int32_t>(GetSequence(0).CpuSpan());
       RewindToLength(0);
       AppendTokens(current_seq);
diff --git a/src/generators.h b/src/generators.h
index 38c695987..0b6dc7cfc 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -125,7 +125,6 @@ struct Generator : LeakChecked<Generator> {
   std::unique_ptr<State> state_;
   std::unique_ptr<Search> search_;
   bool computed_logits_{};  // Set to true in ComputeLogits() and false after appending a token to ensure a 1 to 1 call ratio
-  bool first_switch{true};
 
 private:
   DeviceSpan<int32_t> AllocateInputIdsOnDevice(const cpu_span<int32_t> input_ids);

From f96ea94ab1ac6ac4b37446d780015fdf72ba9237 Mon Sep 17 00:00:00 2001
From: Abhishek Jindal
Date: Thu, 2 Jan 2025 17:03:55 -0800
Subject: [PATCH 4/7] add different Phi3 models

---
 src/generators.cpp | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/src/generators.cpp b/src/generators.cpp
index 522e6dab7..c6e81647b 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -356,8 +356,13 @@ void Generator::GenerateNextToken() {
   // TODO: Extend the solution to make it work for batch size > 1 and num beams > 1
   // Phi3 model switches from short factor to long factor at 4097 (original_max_position_embeddings+1) token, needs Recomputation of Position IDs and KV Cache
   // at this stage which is achieved by rewinding to zero and appending the current sequence
-  if (model_->config_->model.type == "phi3" && search_->params_->search.batch_size == 1 && search_->params_->search.num_beams == 1) {
-    if (search_->GetSequenceLength() == 4097) {
+  if (search_->params_->search.batch_size == 1 && search_->params_->search.num_beams == 1) {
+    if ((search_->GetSequenceLength() == 4097) && (model_->config_->model.type == "phi3" || model_->config_->model.type == "phimoe")) {
+      auto current_seq = cpu_span<int32_t>(GetSequence(0).CpuSpan());
+      RewindToLength(0);
+      AppendTokens(current_seq);
+    }
+    if ((search_->GetSequenceLength() == 8197) && (model_->config_->model.type == "phi3small")) {
       auto current_seq = cpu_span<int32_t>(GetSequence(0).CpuSpan());
       RewindToLength(0);
       AppendTokens(current_seq);
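
Patch 4 duplicates the rewind/append body once per model family; patch 5 below folds the two branches into one condition. The only model-specific piece is the switch length, so it can also be factored out. A sketch of that, using exactly the literals hard-coded in patch 4 (the helper name is hypothetical, not part of the patch set):

// switch_length_sketch.cpp - hypothetical helper, not part of this patch set.
#include <cstddef>
#include <string>

// Literals mirror the checks in patch 4: phi3/phimoe recompute at sequence
// length 4097, phi3small at 8197. Returns 0 for model types with no
// short-to-long factor switch.
std::size_t RopeSwitchLength(const std::string& model_type) {
  if (model_type == "phi3" || model_type == "phimoe") return 4097;
  if (model_type == "phi3small") return 8197;
  return 0;
}
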
From 4c799698d5913b9714acf604705143f64666331a Mon Sep 17 00:00:00 2001
From: Abhishek Jindal
Date: Mon, 6 Jan 2025 15:14:50 -0800
Subject: [PATCH 5/7] update code and comments

---
 src/generators.cpp | 14 ++++++--------
 1 file changed, 6 insertions(+), 8 deletions(-)

diff --git a/src/generators.cpp b/src/generators.cpp
index c6e81647b..b1a46789a 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -353,17 +353,15 @@ void Generator::GenerateNextToken() {
   if (search_->GetSequenceLength() == 0 && !computed_logits_)
     throw std::runtime_error("GenerateNextToken called with no prior state. Please call AppendTokens, SetLogits, or params.SetInputs before calling GenerateNextToken.");
 
-  // TODO: Extend the solution to make it work for batch size > 1 and num beams > 1
+  // TODO: Extend the solution to make it work for batch size > 1, num beams > 1, multimodal and DML
   // Phi3 model switches from short factor to long factor at 4097 (original_max_position_embeddings+1) token, needs Recomputation of Position IDs and KV Cache
   // at this stage which is achieved by rewinding to zero and appending the current sequence
+  // Scenarios where this solution works: Batch size = 1, Num beams = 1, decoder model, EP is either CPU or CUDA
+  // Scenarios where it doesn't work: Batch size > 1 OR Num beams > 1 OR Multimodal model (like phi3 vision) OR EP is DML
   if (search_->params_->search.batch_size == 1 && search_->params_->search.num_beams == 1) {
-    if ((search_->GetSequenceLength() == 4097) && (model_->config_->model.type == "phi3" || model_->config_->model.type == "phimoe")) {
-      auto current_seq = cpu_span<int32_t>(GetSequence(0).CpuSpan());
-      RewindToLength(0);
-      AppendTokens(current_seq);
-    }
-    if ((search_->GetSequenceLength() == 8197) && (model_->config_->model.type == "phi3small")) {
-      auto current_seq = cpu_span<int32_t>(GetSequence(0).CpuSpan());
+    if (((search_->GetSequenceLength() == 4097) && (model_->config_->model.type == "phi3" || model_->config_->model.type == "phimoe")) || ((search_->GetSequenceLength() == 8197) && (model_->config_->model.type == "phi3small"))) {
+      // auto current_seq = cpu_span<int32_t>(GetSequence(0).CpuSpan());
+      auto current_seq = cpu_span<int32_t>(GetSequence(0).CopyDeviceToCpu());
       RewindToLength(0);
       AppendTokens(current_seq);
     }

From f644e984ba60477bb9623ecd166c964663d807a9 Mon Sep 17 00:00:00 2001
From: Abhishek Jindal
Date: Tue, 7 Jan 2025 13:21:09 -0800
Subject: [PATCH 6/7] add original context length and code fixes

---
 src/generators.cpp              | 3 +--
 src/python/py/models/builder.py | 1 +
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/generators.cpp b/src/generators.cpp
index b1a46789a..0772dcdfa 100644
--- a/src/generators.cpp
+++ b/src/generators.cpp
@@ -358,9 +358,8 @@ void Generator::GenerateNextToken() {
   // at this stage which is achieved by rewinding to zero and appending the current sequence
   // Scenarios where this solution works: Batch size = 1, Num beams = 1, decoder model, EP is either CPU or CUDA
   // Scenarios where it doesn't work: Batch size > 1 OR Num beams > 1 OR Multimodal model (like phi3 vision) OR EP is DML
-  if (search_->params_->search.batch_size == 1 && search_->params_->search.num_beams == 1) {
+  if (search_->params_->BatchBeamSize() == 1) {
     if (((search_->GetSequenceLength() == 4097) && (model_->config_->model.type == "phi3" || model_->config_->model.type == "phimoe")) || ((search_->GetSequenceLength() == 8197) && (model_->config_->model.type == "phi3small"))) {
-      // auto current_seq = cpu_span<int32_t>(GetSequence(0).CpuSpan());
       auto current_seq = cpu_span<int32_t>(GetSequence(0).CopyDeviceToCpu());
       RewindToLength(0);
       AppendTokens(current_seq);
diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index e027fabda..c0c4a83e0 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -347,6 +347,7 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
                 "num_key_value_heads": self.num_kv_heads,
             },
             "eos_token_id": config.eos_token_id,
+            "original_context_length": self.original_context_length,
             "pad_token_id": config.pad_token_id if hasattr(config, "pad_token_id") and config.pad_token_id is not None else config.eos_token_id[0] if isinstance(config.eos_token_id, list) else config.eos_token_id,
             "type": self.model_type[ : self.model_type.find("For")].lower(),
             "vocab_size": self.vocab_size,
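
On the BatchBeamSize() simplification in patch 6: assuming the helper returns the product batch_size * num_beams (which is how I read its use in this codebase), the product equals 1 exactly when both factors are 1 for positive values, so the single check is equivalent to the earlier pair. A compilable sketch of the equivalence with stand-in types:

// batch_beam_sketch.cpp - stand-in types; assumes BatchBeamSize() is the
// product batch_size * num_beams, an assumption about the real helper.
#include <cassert>

struct SearchParams { int batch_size{1}; int num_beams{1}; };
struct Params {
  SearchParams search;
  int BatchBeamSize() const { return search.batch_size * search.num_beams; }
};

int main() {
  Params p;
  // For positive values the product is 1 iff both factors are 1, so the single
  // check in patch 6 matches the earlier two-condition check.
  assert((p.BatchBeamSize() == 1) == (p.search.batch_size == 1 && p.search.num_beams == 1));
}
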
From 82fd6e0db9d9a6774c0dc033aec82cc78dbc08e7 Mon Sep 17 00:00:00 2001
From: Abhishek Jindal
Date: Tue, 7 Jan 2025 16:52:37 -0800
Subject: [PATCH 7/7] revert original context len

---
 src/python/py/models/builder.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index c0c4a83e0..e027fabda 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -347,7 +347,6 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
                 "num_key_value_heads": self.num_kv_heads,
             },
             "eos_token_id": config.eos_token_id,
-            "original_context_length": self.original_context_length,
             "pad_token_id": config.pad_token_id if hasattr(config, "pad_token_id") and config.pad_token_id is not None else config.eos_token_id[0] if isinstance(config.eos_token_id, list) else config.eos_token_id,
             "type": self.model_type[ : self.model_type.find("For")].lower(),
             "vocab_size": self.vocab_size,
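
End to end, the recompute is invisible at the API surface: a generation loop that crosses the 4096-token boundary simply pays one slower step at token 4097 while the KV cache is rebuilt. A hedged usage sketch follows; the method names track the onnxruntime-genai C++ wrapper (ort_genai.h) as I understand it, and the exact signatures, model path, and prompt should all be treated as assumptions rather than documented API.

// usage_sketch.cpp - hedged illustration, not from this patch set. Treat the
// wrapper method names, the model path, and the prompt as assumptions.
#include <iostream>
#include "ort_genai.h"

int main() {
  auto model = OgaModel::Create("path/to/phi3-model-dir");  // placeholder path
  auto tokenizer = OgaTokenizer::Create(*model);

  auto sequences = OgaSequences::Create();
  tokenizer->Encode("<|user|>\nSummarize this very long document...\n<|end|>\n<|assistant|>", *sequences);

  auto params = OgaGeneratorParams::Create(*model);
  params->SetSearchOption("max_length", 6000);  // drives generation past the 4096 boundary

  auto generator = OgaGenerator::Create(*model, *params);
  generator->AppendTokenSequences(*sequences);
  while (!generator->IsDone()) {
    // At sequence length 4097 this call transparently rewinds to length 0 and
    // re-appends the sequence so the KV cache is rebuilt with the long factor.
    generator->GenerateNextToken();
  }
  std::cout << "generated sequence length: " << generator->GetSequenceCount(0) << "\n";
  return 0;
}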