From 9998c4855e4bdbcc14cfe2614b0f1358f6fbc17e Mon Sep 17 00:00:00 2001
From: sangjanai
Date: Tue, 2 Jul 2024 09:57:18 +0000
Subject: [PATCH 1/4] fix: support Mistral v0.3

---
 .../src/models/load_model_request.h |  12 +-
 .../src/tensorrt-llm_engine.cc      | 112 ++++++++++++++----
 .../src/tensorrt-llm_engine.h       |  21 +++-
 3 files changed, 111 insertions(+), 34 deletions(-)

diff --git a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/models/load_model_request.h b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/models/load_model_request.h
index 7658ef762..f6deccf6f 100644
--- a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/models/load_model_request.h
+++ b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/models/load_model_request.h
@@ -8,9 +8,9 @@ struct LoadModelRequest {
     int ctx_len = 2048;
     int n_parallel = 1;
     std::string model_path;
-    std::string user_prompt = "<|im_end|>\n<|im_start|>user\n";
-    std::string ai_prompt = "<|im_end|>\n<|im_start|>user\n";
-    std::string system_prompt = "<|im_end|>\n<|im_start|>user\n";
+    std::string user_prompt = "";
+    std::string ai_prompt = "";
+    std::string system_prompt = "";
 };
 
 inline LoadModelRequest fromJson(std::shared_ptr<Json::Value> json_body) {
@@ -19,9 +19,9 @@ inline LoadModelRequest fromJson(std::shared_ptr<Json::Value> json_body) {
     request.ctx_len = json_body->get("ctx_len", 2048).asInt();
     request.n_parallel = json_body->get("n_parallel", 1).asInt();
     request.model_path = json_body->get("model_path", "").asString();
-    request.user_prompt = json_body->get("user_prompt", "<|im_end|>\n<|im_start|>user\n").asString();
-    request.ai_prompt = json_body->get("ai_prompt", "<|im_end|>\n<|im_start|>assistant\n").asString();
-    request.system_prompt = json_body->get("system_prompt", "<|im_start|>system\n").asString();
+    request.user_prompt = json_body->get("user_prompt", "").asString();
+    request.ai_prompt = json_body->get("ai_prompt", "").asString();
+    request.system_prompt = json_body->get("system_prompt", "").asString();
   }
   return request;
 }
diff --git a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
index a11ba3361..9821e400b 100644
--- a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
+++ b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
@@ -21,35 +21,82 @@ using json = nlohmann::json;
 using namespace tensorrtllm;
+namespace {
+    constexpr const int k200OK = 200;
+    constexpr const int k400BadRequest = 400;
+    constexpr const int k409Conflict = 409;
+    constexpr const int k500InternalServerError = 500;
+
+    // https://nvidia.github.io/TensorRT-LLM/_cpp_gen/runtime.html#generationinput-h
+    // stopWordsList
+    // 'im', '_', 'end', '</s>', '<|im_end|>'
+    const std::vector<int32_t> kOpenhermesStopWords = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1};
+    const std::string kOhUserPrompt = "<|im_end|>\n<|im_start|>user\n";
+    const std::string kOhAiPrompt = "<|im_end|>\n<|im_start|>assistant\n";
+    const std::string kOhSystemPrompt = "<|im_start|>system\n";
+    const std::unordered_map<std::string, int> kOpenhermesTemplate = {{"<|im_end|>", 32000}, {"<|im_start|>", 32001}};
+
+    // '[', 'INST', ']', '[INST]', '[', '/', 'INST', ']', '[/INST]', '</s>'
+    const std::vector<int32_t> kMistral_V0_3_StopWords
+        = {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};
+    const std::string kMistralUserPrompt = "[INST] ";
+    const std::string kMistralAiPrompt = "[/INST] ";
+    const std::string kMistralSystemPrompt = "";
+    const std::unordered_map<std::string, int> kMistralTemplate = {{"[INST]", 3}, {"[/INST]", 4}};
+
+    // TODO(sang) This is fragile,
just a temporary solution. Maybe can use a config file or model architect, etc... + bool IsOpenhermes(const std::string& s) { + if (s.find("mistral") != std::string::npos || s.find("Mistral") != std::string::npos) { + return false; + } + return true; + } + + std::string GetUserPrompt(bool is_openhermes) { + if(is_openhermes) { + return kOhUserPrompt; + } + return kMistralUserPrompt; + } -constexpr const int k200OK = 200; -constexpr const int k400BadRequest = 400; -constexpr const int k409Conflict = 409; -constexpr const int k500InternalServerError = 500; + std::string GetAiPrompt(bool is_openhermes) { + if(is_openhermes) { + return kOhAiPrompt; + } + return kMistralAiPrompt; + } + std::string GetSystemPrompt(bool is_openhermes) { + if(is_openhermes) { + return kOhSystemPrompt; + } + return kMistralSystemPrompt; + } +} TensorrtllmEngine::~TensorrtllmEngine() {} void RemoveId(std::vector& vec, int id) { vec.erase(std::remove(vec.begin(), vec.end(), id), vec.end()); } -bool HandleMatch(std::string const& rew_text, std::shared_ptr infer_state) { - if (infer_state->IsComplete()) { +bool HandleMatch(std::string const& rew_text, std::shared_ptr infer_state, bool is_openhermes) { + if (infer_state->IsComplete(is_openhermes)) { return false; } if (infer_state->stop_word_match_len == 0) { - if (rew_text.find('<') != std::string::npos) { // Found "<" anywhere in the text + if ((is_openhermes && rew_text.find('<') != std::string::npos) || + (!is_openhermes && rew_text.find('[') != std::string::npos)) { infer_state->stop_word_match_len++; // Move to next state infer_state->prev_text = rew_text; return true; } } - else if (rew_text == infer_state->sequence[infer_state->stop_word_match_len]) { + else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) { infer_state->stop_word_match_len++; // Move to next state infer_state->prev_text = rew_text; return true; } - else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->sequence[0]) { + else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) { infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start infer_state->prev_text = rew_text; return true; @@ -67,9 +114,11 @@ GenerationInput::TensorPtr TensorrtllmEngine::GetTensorSingleStopWordList(int st } GenerationInput::TensorPtr TensorrtllmEngine::GetTensorChatMLStopWordList() { - std::vector stop_words_tokens - = {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1}; // Extend with -1 for increased length - return gpt_session->getBufferManager().copyFrom(stop_words_tokens, ITensor::makeShape({1, 2, 5}), MemoryType::kGPU); + if(is_openhermes_) { + return gpt_session->getBufferManager().copyFrom(kOpenhermesStopWords, ITensor::makeShape({1, 2, static_cast(kOpenhermesStopWords.size()/2)}), MemoryType::kGPU); + } else { + return gpt_session->getBufferManager().copyFrom(kMistral_V0_3_StopWords, ITensor::makeShape({1, 2, static_cast(kMistral_V0_3_StopWords.size()/2)}), MemoryType::kGPU); + } } GenerationInput TensorrtllmEngine::CreateGenerationInput(std::vector input_ids_host) { @@ -102,7 +151,7 @@ void InferenceThread( TensorrtllmEngine* self, SamplingConfig sampling_config, int input_len, - int outputLen) { + int outputLen, bool is_openhermes) { // Input preparation LOG_INFO << "Inference thread started"; @@ -110,9 +159,9 @@ void InferenceThread( GenerationOutput generation_output = self->CreateGenerationOutput(); // Define the callback to stream each generated token - 
generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output]( + generation_output.onTokenGenerated = [&infer_state, input_len, outputLen, self, &generation_output, is_openhermes]( GenerationOutput::TensorPtr const& output_ids, SizeType step, bool finished) { - LOG_INFO << "Generating tokenizer in thread"; + // LOG_INFO << "Generating tokenizer in thread"; // Assuming the shape of output_ids tensor is (1, 1, 160), where 160 is the number of tokens int output_length = output_ids->getShape().d[2]; // Get the length of output IDs based on the tensor shape // Copy output IDs from GPU to host for printing @@ -120,9 +169,17 @@ void InferenceThread( self->gpt_session->getBufferManager().copy(*output_ids, output_idsHost.data(), MemoryType::kCPU); // Find the last non-zero value in the output IDs starting from the end of the input sequence std::vector output_idsHostDecode(output_idsHost.begin() + input_len, output_idsHost.end()); + RemoveId(output_idsHostDecode, 0); - RemoveId(output_idsHostDecode, 32000); - RemoveId(output_idsHostDecode, 32001); + if(is_openhermes) { + for(auto const& [_, v]: kOpenhermesTemplate) { + RemoveId(output_idsHostDecode, v); + } + } else { + for(auto const& [_, v]: kMistralTemplate) { + RemoveId(output_idsHostDecode, v); + } + } std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode); if (infer_state->prev_pos >= 0 && infer_state->prev_pos < text.size()) { @@ -225,6 +282,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b } } formatted_input += ai_prompt; + // LOG_INFO << formatted_input; // Format the input from user std::shared_ptr infer_state = std::make_shared(); @@ -243,23 +301,25 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b sampling_config.repetitionPenalty = std::vector{request.frequency_penalty}; // Input preparation - std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen); + std::thread inference_thread(InferenceThread, infer_state, input_ids_host, callback, this, sampling_config, input_len, outputLen, is_openhermes_); inference_thread.detach(); // Detach the thread to allow it to run independently - q_->runTaskInQueue([cb = std::move(callback), infer_state]() { + q_->runTaskInQueue([this, cb = std::move(callback), infer_state]() { + // std::string res_str; LOG_INFO << "Preparing to run inference task queue..."; while (true) { // Continuously check if the queue is not empty std::unique_lock lock(infer_state->queue_mutex); // Lock the queue for exclusive access if (!infer_state->texts_to_stream.empty()) { std::string rew_text = infer_state->texts_to_stream.front(); + // res_str += rew_text; infer_state->texts_to_stream.pop(); - if (HandleMatch(rew_text, infer_state) && rew_text != "[DONE]") { + if (HandleMatch(rew_text, infer_state, is_openhermes_) && rew_text != "[DONE]") { continue; }; if (rew_text == "[DONE]") { const std::string str - = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", "", "stop") + = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, "", "stop") + "\n\n" + "data: [DONE]" + "\n\n"; infer_state->is_finished = true; @@ -275,7 +335,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b break; } const std::string text_to_stream - = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", rew_text) + "\n\n"; + = "data: " 
+ tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, rew_text) + "\n\n"; lock.unlock(); // Unlock as soon as possible infer_state->prev_text = rew_text; @@ -293,6 +353,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b lock.unlock(); } } + // LOG_INFO << res_str; }); LOG_INFO << "Inference completed"; @@ -302,11 +363,12 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b void TensorrtllmEngine::LoadModel(std::shared_ptr json_body, std::function&& callback) { model::LoadModelRequest request = model::fromJson(json_body); std::filesystem::path model_dir = request.model_path; + is_openhermes_ = IsOpenhermes(request.model_path); int ctx_len = request.ctx_len; - this->user_prompt = request.user_prompt; - this->ai_prompt = request.ai_prompt; - this->system_prompt = request.system_prompt; + this->user_prompt = request.user_prompt.empty() ? GetUserPrompt(is_openhermes_) : request.user_prompt; + this->ai_prompt = request.ai_prompt.empty() ? GetAiPrompt(is_openhermes_) : request.ai_prompt; + this->system_prompt = request.system_prompt.empty() ? GetSystemPrompt(is_openhermes_) : request.system_prompt; this->model_id_ = GetModelId(*json_body); logger = std::make_shared(); diff --git a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h index cc971f7eb..a0dfe4051 100644 --- a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h +++ b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h @@ -73,7 +73,8 @@ struct InferenceState { std::queue texts_to_stream; std::mutex queue_mutex; // Mutex to protect access to textsToStream size_t stop_word_match_len = 0; - std::vector sequence{"<", "|", "im", "_", "end", "|", ">"}; + std::vector sequence_openhermes = {"<", "|", "im", "_", "end", "|", ">"}; + std::vector sequence_mistral = {"[", "INST", "]"}; int token_gen_count = 0; void Reset() { @@ -81,8 +82,21 @@ struct InferenceState { prev_text = ""; } - bool IsComplete() const { - return stop_word_match_len >= sequence.size(); + bool IsComplete(bool is_openhermes) const { + if(is_openhermes) { + return stop_word_match_len >= sequence_openhermes.size(); + } else { + return stop_word_match_len >= sequence_mistral.size(); + } + } + + const std::string& GetSequence(bool is_openhermes, size_t index) { + if(is_openhermes) { + return sequence_openhermes[index]; + } else { + return sequence_mistral[index]; + } + } }; @@ -138,6 +152,7 @@ class TensorrtllmEngine : public EngineI { uint64_t start_time_; std::atomic model_loaded_; std::unique_ptr q_; + bool is_openhermes_ = true; }; } // namespace inferences From 3ac131ac5834dc8695c9ff3288b87e2d17ebb368 Mon Sep 17 00:00:00 2001 From: sangjanai Date: Tue, 2 Jul 2024 10:10:15 +0000 Subject: [PATCH 2/4] refactor: rename class variables --- .../src/tensorrt-llm_engine.cc | 48 +++++++++---------- .../src/tensorrt-llm_engine.h | 17 ++++--- 2 files changed, 32 insertions(+), 33 deletions(-) diff --git a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc index 9821e400b..ea471a5c1 100644 --- a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc +++ b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc @@ -123,12 +123,12 @@ GenerationInput::TensorPtr TensorrtllmEngine::GetTensorChatMLStopWordList() { GenerationInput TensorrtllmEngine::CreateGenerationInput(std::vector input_ids_host) { int input_len = 
input_ids_host.size(); - std::vector input_lengths_host(batchSize, input_len); + std::vector input_lengths_host(batch_size_, input_len); GenerationInput::TensorPtr input_lengths - = gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batchSize}), MemoryType::kGPU); + = gpt_session->getBufferManager().copyFrom(input_lengths_host, ITensor::makeShape({batch_size_}), MemoryType::kGPU); GenerationInput::TensorPtr input_ids = gpt_session->getBufferManager().copyFrom( - input_ids_host, ITensor::makeShape({batchSize, input_len}), MemoryType::kGPU); - GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config->usePackedInput()}; + input_ids_host, ITensor::makeShape({batch_size_, input_len}), MemoryType::kGPU); + GenerationInput generation_input{0, 0, input_ids, input_lengths, model_config_->usePackedInput()}; generation_input.stopWordsList = GetTensorChatMLStopWordList(); LOG_INFO << "Create generation input successfully"; @@ -249,7 +249,7 @@ bool TensorrtllmEngine::CheckModelLoaded(std::function json_body, std::function&& callback) { inferences::ChatCompletionRequest request = inferences::fromJson(json_body); - std::string formatted_input = pre_prompt; + std::string formatted_input = pre_prompt_; nlohmann::json data; // data["stream"] = completion.stream; // data["n_predict"] = completion.max_tokens; @@ -261,17 +261,17 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b std::string input_role = message["role"].asString(); std::string role; if (input_role == "user") { - role = user_prompt; + role = user_prompt_; std::string content = message["content"].asString(); formatted_input += role + content; } else if (input_role == "assistant") { - role = ai_prompt; + role = ai_prompt_; std::string content = message["content"].asString(); formatted_input += role + content; } else if (input_role == "system") { - role = system_prompt; + role = system_prompt_; std::string content = message["content"].asString(); formatted_input = role + content + formatted_input; } @@ -281,7 +281,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b formatted_input += role + content; } } - formatted_input += ai_prompt; + formatted_input += ai_prompt_; // LOG_INFO << formatted_input; // Format the input from user @@ -366,14 +366,14 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr json_body, std::f is_openhermes_ = IsOpenhermes(request.model_path); int ctx_len = request.ctx_len; - this->user_prompt = request.user_prompt.empty() ? GetUserPrompt(is_openhermes_) : request.user_prompt; - this->ai_prompt = request.ai_prompt.empty() ? GetAiPrompt(is_openhermes_) : request.ai_prompt; - this->system_prompt = request.system_prompt.empty() ? GetSystemPrompt(is_openhermes_) : request.system_prompt; - this->model_id_ = GetModelId(*json_body); + user_prompt_ = request.user_prompt.empty() ? GetUserPrompt(is_openhermes_) : request.user_prompt; + ai_prompt_ = request.ai_prompt.empty() ? GetAiPrompt(is_openhermes_) : request.ai_prompt; + system_prompt_ = request.system_prompt.empty() ? 
GetSystemPrompt(is_openhermes_) : request.system_prompt; + model_id_ = GetModelId(*json_body); - logger = std::make_shared(); - logger->setLevel(nvinfer1::ILogger::Severity::kINFO); - initTrtLlmPlugins(logger.get()); + logger_ = std::make_shared(); + logger_->setLevel(nvinfer1::ILogger::Severity::kINFO); + initTrtLlmPlugins(logger_.get()); std::filesystem::path tokenizer_model_name = model_dir / "tokenizer.model"; cortex_tokenizer = std::make_unique(tokenizer_model_name.string()); @@ -382,20 +382,20 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr json_body, std::f std::filesystem::path json_file_name = model_dir / "config.json"; auto json = GptJsonConfig::parse(json_file_name); auto config = json.getModelConfig(); - model_config = std::make_unique(config); + model_config_ = std::make_unique(config); auto world_config = WorldConfig::mpi(1, json.getTensorParallelism(), json.getPipelineParallelism()); LOG_INFO << "Loaded config from " << json_file_name.string(); // auto dtype = model_config->getDataType(); // Currently doing fixed session config - session_config.maxBatchSize = batchSize; - session_config.maxBeamWidth = 1; // Fixed for simplicity - session_config.maxSequenceLength = ctx_len; - session_config.cudaGraphMode = true; // Fixed for simplicity + session_config_.maxBatchSize = batch_size_; + session_config_.maxBeamWidth = 1; // Fixed for simplicity + session_config_.maxSequenceLength = ctx_len; + session_config_.cudaGraphMode = true; // Fixed for simplicity // Init gpt_session auto model_path = model_dir / json.engineFilename(world_config, model_id_); - gpt_session = std::make_unique(session_config, *model_config, world_config, model_path.string(), logger); + gpt_session = std::make_unique(session_config_, *model_config_, world_config, model_path.string(), logger_); model_loaded_ = true; if (q_ == nullptr) { @@ -427,8 +427,8 @@ void TensorrtllmEngine::UnloadModel(std::shared_ptr json_body, std: gpt_session.reset(); cortex_tokenizer.reset(); q_.reset(); - model_config.reset(); - logger.reset(); + model_config_.reset(); + logger_.reset(); model_loaded_ = false; Json::Value json_resp; diff --git a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h index a0dfe4051..d9b9a6162 100644 --- a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h +++ b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h @@ -139,15 +139,14 @@ class TensorrtllmEngine : public EngineI { bool CheckModelLoaded( std::function& callback); - GptSession::Config session_config{1, 1, 1}; - SamplingConfig sampling_config{1}; - std::unique_ptr model_config; - std::shared_ptr logger; - std::string user_prompt; - std::string ai_prompt; - std::string system_prompt; - std::string pre_prompt; - int batchSize = 1; + GptSession::Config session_config_{1, 1, 1}; + std::unique_ptr model_config_; + std::shared_ptr logger_; + std::string user_prompt_; + std::string ai_prompt_; + std::string system_prompt_; + std::string pre_prompt_; + int batch_size_ = 1; std::string model_id_; uint64_t start_time_; std::atomic model_loaded_; From c7c85166a5ca099521f465cc8d8a2a63bebc2b7e Mon Sep 17 00:00:00 2001 From: sangjanai Date: Tue, 2 Jul 2024 14:39:52 +0000 Subject: [PATCH 3/4] fix: rewind text if does not match pattern --- .../src/tensorrt-llm_engine.cc | 55 +++++++++++++++---- .../src/tensorrt-llm_engine.h | 5 +- 2 files changed, 45 insertions(+), 15 deletions(-) diff --git a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc 
b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc index ea471a5c1..bda5f263b 100644 --- a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc +++ b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc @@ -79,30 +79,61 @@ void RemoveId(std::vector& vec, int id) { vec.erase(std::remove(vec.begin(), vec.end(), id), vec.end()); } -bool HandleMatch(std::string const& rew_text, std::shared_ptr infer_state, bool is_openhermes) { +bool HandleMatch(std::string const& rew_text, + std::shared_ptr infer_state, + std::function cb, + bool is_openhermes) { if (infer_state->IsComplete(is_openhermes)) { + infer_state->rewind_strs.clear(); return false; } if (infer_state->stop_word_match_len == 0) { if ((is_openhermes && rew_text.find('<') != std::string::npos) || (!is_openhermes && rew_text.find('[') != std::string::npos)) { infer_state->stop_word_match_len++; // Move to next state - infer_state->prev_text = rew_text; + infer_state->rewind_strs.push_back(rew_text); return true; } - } - else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) { + } else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) { infer_state->stop_word_match_len++; // Move to next state - infer_state->prev_text = rew_text; + infer_state->rewind_strs.push_back(rew_text); return true; - } - else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) { + } else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) { infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start - infer_state->prev_text = rew_text; + // response cache data + for(auto const& s: infer_state->rewind_strs) { + // std::cout << s; + const std::string text_to_stream + = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n"; + Json::Value resp_data; + resp_data["data"] = text_to_stream; + Json::Value status; + status["is_done"] = false; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = k200OK; + cb(std::move(status), std::move(resp_data)); + } + infer_state->rewind_strs.clear(); + infer_state->rewind_strs.push_back(rew_text); return true; - } - else { + } else { infer_state->Reset(); + // response cache data + for(auto const& s: infer_state->rewind_strs) { + // std::cout << s; + const std::string text_to_stream + = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n"; + Json::Value resp_data; + resp_data["data"] = text_to_stream; + Json::Value status; + status["is_done"] = false; + status["has_error"] = false; + status["is_stream"] = true; + status["status_code"] = k200OK; + cb(std::move(status), std::move(resp_data)); + } + infer_state->rewind_strs.clear(); return false; // Reset to start if sequence breaks } return false; @@ -313,7 +344,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b std::string rew_text = infer_state->texts_to_stream.front(); // res_str += rew_text; infer_state->texts_to_stream.pop(); - if (HandleMatch(rew_text, infer_state, is_openhermes_) && rew_text != "[DONE]") { + if (HandleMatch(rew_text, infer_state, cb, is_openhermes_) && rew_text != "[DONE]") { continue; }; @@ -338,7 +369,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b = "data: " + 
tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), model_id_, rew_text) + "\n\n"; lock.unlock(); // Unlock as soon as possible - infer_state->prev_text = rew_text; + // std::cout << rew_text; Json::Value resp_data; resp_data["data"] = text_to_stream; diff --git a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h index d9b9a6162..6f1200790 100644 --- a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h +++ b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h @@ -68,7 +68,6 @@ class Tokenizer { struct InferenceState { int prev_pos{0}; - std::string prev_text; bool is_finished; std::queue texts_to_stream; std::mutex queue_mutex; // Mutex to protect access to textsToStream @@ -76,10 +75,10 @@ struct InferenceState { std::vector sequence_openhermes = {"<", "|", "im", "_", "end", "|", ">"}; std::vector sequence_mistral = {"[", "INST", "]"}; int token_gen_count = 0; + std::vector rewind_strs; void Reset() { - stop_word_match_len = 0; - prev_text = ""; + stop_word_match_len = 0; } bool IsComplete(bool is_openhermes) const { From 9303456022ea78b5db317181fa80cdc2eacd6895 Mon Sep 17 00:00:00 2001 From: sangjanai Date: Wed, 3 Jul 2024 12:53:34 +0000 Subject: [PATCH 4/4] fix: template issue for tokenizer v3 --- .../src/tensorrt-llm_engine.cc | 107 +++++++----------- .../src/tensorrt-llm_engine.h | 1 - 2 files changed, 41 insertions(+), 67 deletions(-) diff --git a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc index bda5f263b..2d2e8ed53 100644 --- a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc +++ b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc @@ -39,10 +39,13 @@ namespace { // '[', 'INST', ']', '[INST]', ''[, '/' , 'INST',']', '[/INST]', '' const std::vector kMistral_V0_3_StopWords = {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1}; - const std::string kMistralUserPrompt = "[INST] "; - const std::string kMistralAiPrompt = "[/INST] "; - const std::string kMistralSystemPrompt = ""; - const std::unordered_map kMistralTemplate = {{"[INST]", 3} , {"[/INST]", 4}}; + + enum class MistralTemplate: int32_t { + kBos = 1, + kEos = 2, + kBeginInst = 3, + kEndInst = 4 + }; // TODO(sang) This is fragile, just a temporary solution. Maybe can use a config file or model architect, etc... 
bool IsOpenhermes(const std::string& s) { @@ -51,27 +54,6 @@ namespace { } return true; } - - std::string GetUserPrompt(bool is_openhermes) { - if(is_openhermes) { - return kOhUserPrompt; - } - return kMistralUserPrompt; - } - - std::string GetAiPrompt(bool is_openhermes) { - if(is_openhermes) { - return kOhAiPrompt; - } - return kMistralAiPrompt; - } - - std::string GetSystemPrompt(bool is_openhermes) { - if(is_openhermes) { - return kOhSystemPrompt; - } - return kMistralSystemPrompt; - } } TensorrtllmEngine::~TensorrtllmEngine() {} @@ -84,56 +66,22 @@ bool HandleMatch(std::string const& rew_text, std::function cb, bool is_openhermes) { if (infer_state->IsComplete(is_openhermes)) { - infer_state->rewind_strs.clear(); return false; } if (infer_state->stop_word_match_len == 0) { if ((is_openhermes && rew_text.find('<') != std::string::npos) || (!is_openhermes && rew_text.find('[') != std::string::npos)) { infer_state->stop_word_match_len++; // Move to next state - infer_state->rewind_strs.push_back(rew_text); return true; } } else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) { infer_state->stop_word_match_len++; // Move to next state - infer_state->rewind_strs.push_back(rew_text); return true; } else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) { infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start - // response cache data - for(auto const& s: infer_state->rewind_strs) { - // std::cout << s; - const std::string text_to_stream - = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n"; - Json::Value resp_data; - resp_data["data"] = text_to_stream; - Json::Value status; - status["is_done"] = false; - status["has_error"] = false; - status["is_stream"] = true; - status["status_code"] = k200OK; - cb(std::move(status), std::move(resp_data)); - } - infer_state->rewind_strs.clear(); - infer_state->rewind_strs.push_back(rew_text); return true; } else { infer_state->Reset(); - // response cache data - for(auto const& s: infer_state->rewind_strs) { - // std::cout << s; - const std::string text_to_stream - = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n"; - Json::Value resp_data; - resp_data["data"] = text_to_stream; - Json::Value status; - status["is_done"] = false; - status["has_error"] = false; - status["is_stream"] = true; - status["status_code"] = k200OK; - cb(std::move(status), std::move(resp_data)); - } - infer_state->rewind_strs.clear(); return false; // Reset to start if sequence breaks } return false; @@ -207,9 +155,8 @@ void InferenceThread( RemoveId(output_idsHostDecode, v); } } else { - for(auto const& [_, v]: kMistralTemplate) { - RemoveId(output_idsHostDecode, v); - } + RemoveId(output_idsHostDecode, static_cast(MistralTemplate::kBeginInst)); + RemoveId(output_idsHostDecode, static_cast(MistralTemplate::kEndInst)); } std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode); @@ -287,7 +234,12 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b data["presence_penalty"] = request.presence_penalty; Json::Value const& messages = request.messages; + // tokens for Mistral v0.3 + // TODO(sang): too much hard code here, need to refactor it soon + std::vector tokens = {static_cast(MistralTemplate::kBos)}; + // Format the input from user + int msg_count = 0; for (auto const& message : messages) { 
std::string input_role = message["role"].asString(); std::string role; @@ -295,11 +247,24 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b role = user_prompt_; std::string content = message["content"].asString(); formatted_input += role + content; + if(!is_openhermes_) { + auto new_tokens = cortex_tokenizer->Encode(content); + new_tokens.insert(new_tokens.begin(), static_cast(MistralTemplate::kBeginInst)); + new_tokens.push_back(static_cast(MistralTemplate::kEndInst)); + tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end()); + } } else if (input_role == "assistant") { role = ai_prompt_; std::string content = message["content"].asString(); formatted_input += role + content; + if(!is_openhermes_) { + auto new_tokens = cortex_tokenizer->Encode(content); + if(msg_count == messages.size() - 1) { + new_tokens.push_back(static_cast(MistralTemplate::kEos)); + } + tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end()); + } } else if (input_role == "system") { role = system_prompt_; @@ -311,6 +276,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b std::string content = message["content"].asString(); formatted_input += role + content; } + msg_count++; } formatted_input += ai_prompt_; // LOG_INFO << formatted_input; @@ -318,7 +284,13 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr json_b std::shared_ptr infer_state = std::make_shared(); - std::vector input_ids_host = cortex_tokenizer->Encode(formatted_input); + std::vector input_ids_host; + if(is_openhermes_) { + input_ids_host = cortex_tokenizer->Encode(formatted_input); + } else { + input_ids_host = tokens; + } + int const input_len = input_ids_host.size(); int const outputLen = request.max_tokens - input_len; @@ -397,9 +369,12 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr json_body, std::f is_openhermes_ = IsOpenhermes(request.model_path); int ctx_len = request.ctx_len; - user_prompt_ = request.user_prompt.empty() ? GetUserPrompt(is_openhermes_) : request.user_prompt; - ai_prompt_ = request.ai_prompt.empty() ? GetAiPrompt(is_openhermes_) : request.ai_prompt; - system_prompt_ = request.system_prompt.empty() ? GetSystemPrompt(is_openhermes_) : request.system_prompt; + // We only support 2 models for now, it is ugly but it works :( + if(is_openhermes_) { + user_prompt_ = request.user_prompt.empty() ? kOhUserPrompt : request.user_prompt; + ai_prompt_ = request.ai_prompt.empty() ? kOhAiPrompt : request.ai_prompt; + system_prompt_ = request.system_prompt.empty() ? kOhSystemPrompt : request.system_prompt; + } model_id_ = GetModelId(*json_body); logger_ = std::make_shared(); diff --git a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h index 6f1200790..dd0036c53 100644 --- a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h +++ b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h @@ -75,7 +75,6 @@ struct InferenceState { std::vector sequence_openhermes = {"<", "|", "im", "_", "end", "|", ">"}; std::vector sequence_mistral = {"[", "INST", "]"}; int token_gen_count = 0; - std::vector rewind_strs; void Reset() { stop_word_match_len = 0;
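A note on the stop-word tensors introduced in PATCH 1: kOpenhermesStopWords and kMistral_V0_3_StopWords follow the flattened stopWordsList layout of the TensorRT-LLM runtime header linked in the code comment. The tensor has shape [1, 2, N]; the first row is the concatenation of every stop sequence's token ids, the second row is the cumulative end offset of each sequence, padded with -1 up to N. A minimal sketch of that flattening (FlattenStopWords is an illustrative helper name, not something the patch adds):

#include <cstdint>
#include <vector>

// Build the flat payload copied to the GPU with ITensor::makeShape({1, 2, N}).
// Row 0: concatenated token ids of all stop sequences.
// Row 1: cumulative end offset of each sequence, padded with -1.
std::vector<int32_t> FlattenStopWords(const std::vector<std::vector<int32_t>>& stop_seqs) {
    std::vector<int32_t> ids;      // row 0
    std::vector<int32_t> offsets;  // row 1
    for (const auto& seq : stop_seqs) {
        ids.insert(ids.end(), seq.begin(), seq.end());
        offsets.push_back(static_cast<int32_t>(ids.size()));  // end position of this sequence
    }
    offsets.resize(ids.size(), -1);   // pad so both rows hold N entries
    std::vector<int32_t> flat = ids;
    flat.insert(flat.end(), offsets.begin(), offsets.end());
    return flat;                      // N = flat.size() / 2
}

// Example: {{321, 28730, 416}, {2}, {32000}}, i.e. "im" "_" "end", the EOS id, "<|im_end|>",
// flattens to {321, 28730, 416, 2, 32000, 3, 4, 5, -1, -1}, which is exactly kOpenhermesStopWords.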
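The streaming side suppresses the textual form of those stop markers. After PATCH 4, HandleMatch() keeps a match counter against sequence_openhermes ("<", "|", "im", "_", "end", "|", ">") or sequence_mistral ("[", "INST", "]") and returns true while an incoming decoded chunk could still be part of the marker. A simplified stand-alone view of that state machine, leaving out the queue, callback and [DONE] plumbing (and the PATCH 3 rewind buffer, which PATCH 4 removes again):

#include <cstddef>
#include <string>
#include <vector>

// Returns true while chunks should be withheld because they may belong to the stop marker.
class StopMarkerMatcher {
 public:
    explicit StopMarkerMatcher(bool is_openhermes)
        : seq_(is_openhermes
                   ? std::vector<std::string>{"<", "|", "im", "_", "end", "|", ">"}
                   : std::vector<std::string>{"[", "INST", "]"}),
          trigger_(is_openhermes ? '<' : '[') {}

    bool ShouldSuppress(const std::string& chunk) {
        if (matched_ >= seq_.size()) return false;                  // marker already complete
        if (matched_ == 0) {
            if (chunk.find(trigger_) != std::string::npos) {        // possible start of the marker
                matched_ = 1;
                return true;
            }
            return false;
        }
        if (chunk == seq_[matched_]) { ++matched_; return true; }   // next piece of the marker
        if (chunk == seq_[0])        { matched_ = 1; return true; } // marker restarts
        matched_ = 0;                                               // false alarm, forward the chunk
        return false;
    }

 private:
    std::vector<std::string> seq_;
    char trigger_;
    std::size_t matched_ = 0;
};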
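PATCH 4 stops driving Mistral v0.3 through string templates and assembles the prompt directly from control-token ids (kBos = 1, kEos = 2, kBeginInst = 3, kEndInst = 4): the sequence starts with <s>, every user turn is wrapped in [INST] ... [/INST], and only a trailing assistant turn is closed with </s>; as in the patch, system-message content is not tokenized on this path. A condensed sketch, with a stand-in Message struct and encode callback in place of the Json::Value messages and cortex_tokenizer used in the patch:

#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <vector>

enum class MistralTemplate : int32_t { kBos = 1, kEos = 2, kBeginInst = 3, kEndInst = 4 };

struct Message {           // stand-in for one entry of the request's "messages" array
    std::string role;
    std::string content;
};

std::vector<int32_t> BuildMistralV03Prompt(
        const std::vector<Message>& messages,
        const std::function<std::vector<int32_t>(const std::string&)>& encode) {
    std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTemplate::kBos)};  // <s>
    for (std::size_t i = 0; i < messages.size(); ++i) {
        auto ids = encode(messages[i].content);
        if (messages[i].role == "user") {
            ids.insert(ids.begin(), static_cast<int32_t>(MistralTemplate::kBeginInst));  // [INST]
            ids.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));              // [/INST]
        } else if (messages[i].role == "assistant" && i + 1 == messages.size()) {
            ids.push_back(static_cast<int32_t>(MistralTemplate::kEos));                  // </s>
        }
        tokens.insert(tokens.end(), ids.begin(), ids.end());
    }
    return tokens;  // used as input_ids_host instead of Encode(formatted_input)
}

Because the control ids never pass through the text template, the decode loop only needs to strip kBeginInst and kEndInst before calling Decode(), which is why the kMistralTemplate map disappears in this commit.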