diff --git a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
index bda5f263b..2d2e8ed53 100644
--- a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
+++ b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
@@ -39,10 +39,13 @@ namespace {
     // '[', 'INST', ']', '[INST]', '[', '/' , 'INST', ']', '[/INST]', ''
     const std::vector<int32_t> kMistral_V0_3_StopWords
         = {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};
-    const std::string kMistralUserPrompt = "[INST] ";
-    const std::string kMistralAiPrompt = "[/INST] ";
-    const std::string kMistralSystemPrompt = "";
-    const std::unordered_map<std::string, int32_t> kMistralTemplate = {{"[INST]", 3} , {"[/INST]", 4}};
+
+    enum class MistralTemplate: int32_t {
+        kBos = 1,
+        kEos = 2,
+        kBeginInst = 3,
+        kEndInst = 4
+    };
 
     // TODO(sang) This is fragile, just a temporary solution. Maybe can use a config file or model architect, etc...
     bool IsOpenhermes(const std::string& s) {
@@ -51,27 +54,6 @@ namespace {
         }
         return true;
     }
-
-    std::string GetUserPrompt(bool is_openhermes) {
-        if(is_openhermes) {
-            return kOhUserPrompt;
-        }
-        return kMistralUserPrompt;
-    }
-
-    std::string GetAiPrompt(bool is_openhermes) {
-        if(is_openhermes) {
-            return kOhAiPrompt;
-        }
-        return kMistralAiPrompt;
-    }
-
-    std::string GetSystemPrompt(bool is_openhermes) {
-        if(is_openhermes) {
-            return kOhSystemPrompt;
-        }
-        return kMistralSystemPrompt;
-    }
 }
 
 TensorrtllmEngine::~TensorrtllmEngine() {}
@@ -84,56 +66,22 @@ bool HandleMatch(std::string const& rew_text,
                  std::function<void(Json::Value&&, Json::Value&&)> cb,
                  bool is_openhermes) {
     if (infer_state->IsComplete(is_openhermes)) {
-        infer_state->rewind_strs.clear();
         return false;
     }
     if (infer_state->stop_word_match_len == 0) {
         if ((is_openhermes && rew_text.find('<') != std::string::npos)
             || (!is_openhermes && rew_text.find('[') != std::string::npos)) {
             infer_state->stop_word_match_len++; // Move to next state
-            infer_state->rewind_strs.push_back(rew_text);
             return true;
         }
     } else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
         infer_state->stop_word_match_len++; // Move to next state
-        infer_state->rewind_strs.push_back(rew_text);
         return true;
     } else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
         infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start
-        // response cache data
-        for(auto const& s: infer_state->rewind_strs) {
-            // std::cout << s;
-            const std::string text_to_stream
-                = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
-            Json::Value resp_data;
-            resp_data["data"] = text_to_stream;
-            Json::Value status;
-            status["is_done"] = false;
-            status["has_error"] = false;
-            status["is_stream"] = true;
-            status["status_code"] = k200OK;
-            cb(std::move(status), std::move(resp_data));
-        }
-        infer_state->rewind_strs.clear();
-        infer_state->rewind_strs.push_back(rew_text);
         return true;
     } else {
         infer_state->Reset();
-        // response cache data
-        for(auto const& s: infer_state->rewind_strs) {
-            // std::cout << s;
-            const std::string text_to_stream
-                = "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
-            Json::Value resp_data;
-            resp_data["data"] = text_to_stream;
-            Json::Value status;
-            status["is_done"] = false;
-            status["has_error"] = false;
-            status["is_stream"] = true;
-            status["status_code"] = k200OK;
-            cb(std::move(status), std::move(resp_data));
-        }
-        infer_state->rewind_strs.clear();
         return false; // Reset to start if sequence breaks
     }
     return false;
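
For context on the streaming path: with the rewind buffer removed, HandleMatch reduces to a plain incremental matcher over the streamed pieces. Below is a minimal standalone sketch of that state machine; `StopWordMatcher` and `Feed` are illustrative names, not code from this patch, and the start-state check is simplified.

```cpp
#include <string>
#include <vector>

// Minimal standalone model of the matcher HandleMatch implements
// (illustrative; the real code starts a match on a substring hit,
// e.g. rew_text.find('[') for Mistral, and lives in InferenceState).
struct StopWordMatcher {
    std::vector<std::string> sequence{"[", "INST", "]"};  // cf. sequence_mistral
    size_t match_len = 0;

    bool Complete() const { return match_len == sequence.size(); }

    // Returns true while `piece` should be held back from the stream
    // because it may still be part of the stop word.
    bool Feed(const std::string& piece) {
        if (Complete()) return false;
        if (match_len == 0) {
            if (piece.find(sequence[0]) != std::string::npos) {
                ++match_len;              // move to the next state
                return true;
            }
        } else if (piece == sequence[match_len]) {
            ++match_len;                  // move to the next state
            return true;
        } else if (piece == sequence[0]) {
            match_len = 1;                // sequence broke but re-opened a match
            return true;
        } else {
            match_len = 0;                // mismatch: reset to the start state
        }
        return false;
    }
};
```

One behavioural trade-off visible in this hunk: pieces held back during a partial match used to be replayed through `cb` when the match broke; after this change they are dropped instead.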
@@ -207,9 +155,8 @@ void InferenceThread(
             RemoveId(output_idsHostDecode, v);
         }
     } else {
-        for(auto const& [_, v]: kMistralTemplate) {
-            RemoveId(output_idsHostDecode, v);
-        }
+        RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kBeginInst));
+        RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kEndInst));
     }
 
     std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode);
@@ -287,7 +234,12 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
     data["presence_penalty"] = request.presence_penalty;
     Json::Value const& messages = request.messages;
 
+    // tokens for Mistral v0.3
+    // TODO(sang): too much hard code here, need to refactor it soon
+    std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTemplate::kBos)};
+
     // Format the input from user
+    int msg_count = 0;
     for (auto const& message : messages) {
         std::string input_role = message["role"].asString();
         std::string role;
@@ -295,11 +247,24 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
             role = user_prompt_;
             std::string content = message["content"].asString();
             formatted_input += role + content;
+            if(!is_openhermes_) {
+                auto new_tokens = cortex_tokenizer->Encode(content);
+                new_tokens.insert(new_tokens.begin(), static_cast<int32_t>(MistralTemplate::kBeginInst));
+                new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));
+                tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+            }
         } else if (input_role == "assistant") {
             role = ai_prompt_;
             std::string content = message["content"].asString();
             formatted_input += role + content;
+            if(!is_openhermes_) {
+                auto new_tokens = cortex_tokenizer->Encode(content);
+                if(msg_count == messages.size() - 1) {
+                    new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEos));
+                }
+                tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
+            }
         } else if (input_role == "system") {
             role = system_prompt_;
@@ -311,6 +276,7 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
             std::string content = message["content"].asString();
             formatted_input += role + content;
         }
+        msg_count++;
     }
     formatted_input += ai_prompt_;
     // LOG_INFO << formatted_input;
@@ -318,7 +284,13 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
 
     std::shared_ptr<InferenceState> infer_state = std::make_shared<InferenceState>();
 
-    std::vector<int32_t> input_ids_host = cortex_tokenizer->Encode(formatted_input);
+    std::vector<int32_t> input_ids_host;
+    if(is_openhermes_) {
+        input_ids_host = cortex_tokenizer->Encode(formatted_input);
+    } else {
+        input_ids_host = tokens;
+    }
+
     int const input_len = input_ids_host.size();
     int const outputLen = request.max_tokens - input_len;
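
Token-wise, the added branches build the Mistral v0.3 instruct layout directly in id space: [BOS], then [INST] user [/INST] per user turn, with [EOS] appended only after a final assistant turn. A self-contained sketch of the same assembly (`Msg` and the stub `Encode` are hypothetical stand-ins for the request message and `cortex_tokenizer->Encode`):

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Mirrors the enum added by this patch.
enum class MistralTemplate : int32_t { kBos = 1, kEos = 2, kBeginInst = 3, kEndInst = 4 };

// Hypothetical stand-ins: Msg models one chat message; Encode stands in
// for cortex_tokenizer->Encode and returns dummy ids here.
struct Msg { std::string role; std::string content; };
std::vector<int32_t> Encode(const std::string& /*content*/) { return {100, 101}; }

// Builds [BOS] [INST] user [/INST] assistant ... [EOS], where [EOS] is
// appended only when the last message is an assistant turn, matching the
// loop in HandleChatCompletion on the Mistral path.
std::vector<int32_t> BuildMistralTokens(const std::vector<Msg>& messages) {
    std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTemplate::kBos)};
    for (size_t i = 0; i < messages.size(); ++i) {
        auto ids = Encode(messages[i].content);
        if (messages[i].role == "user") {
            ids.insert(ids.begin(), static_cast<int32_t>(MistralTemplate::kBeginInst));
            ids.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));
        } else if (messages[i].role == "assistant" && i + 1 == messages.size()) {
            ids.push_back(static_cast<int32_t>(MistralTemplate::kEos));
        }
        tokens.insert(tokens.end(), ids.begin(), ids.end());
    }
    return tokens;  // e.g. {1, 3, 100, 101, 4} for a single user turn
}
```

Note that, as far as these hunks show, system messages contribute nothing to `tokens` on the Mistral path; only the string-level `formatted_input`, which Mistral no longer consumes, sees them.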
@@ -397,9 +369,12 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::f
     is_openhermes_ = IsOpenhermes(request.model_path);
 
     int ctx_len = request.ctx_len;
-    user_prompt_ = request.user_prompt.empty() ? GetUserPrompt(is_openhermes_) : request.user_prompt;
-    ai_prompt_ = request.ai_prompt.empty() ? GetAiPrompt(is_openhermes_) : request.ai_prompt;
-    system_prompt_ = request.system_prompt.empty() ? GetSystemPrompt(is_openhermes_) : request.system_prompt;
+    // We only support 2 models for now, it is ugly but it works :(
+    if(is_openhermes_) {
+        user_prompt_ = request.user_prompt.empty() ? kOhUserPrompt : request.user_prompt;
+        ai_prompt_ = request.ai_prompt.empty() ? kOhAiPrompt : request.ai_prompt;
+        system_prompt_ = request.system_prompt.empty() ? kOhSystemPrompt : request.system_prompt;
+    }
     model_id_ = GetModelId(*json_body);
 
     logger_ = std::make_shared<TllmLogger>();
diff --git a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h
index 6f1200790..dd0036c53 100644
--- a/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h
+++ b/cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.h
@@ -75,7 +75,6 @@ struct InferenceState {
     std::vector<std::string> sequence_openhermes = {"<", "|", "im", "_", "end", "|", ">"};
    std::vector<std::string> sequence_mistral = {"[", "INST", "]"};
     int token_gen_count = 0;
-    std::vector<std::string> rewind_strs;
 
     void Reset() {
         stop_word_match_len = 0;
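
Relatedly, the InferenceThread hunk above now strips the two template ids explicitly before decoding, so [INST]/[/INST] never leak into the streamed text. RemoveId itself is defined outside this diff; the sketch below assumes the erase-remove implementation implied by the call sites, so treat the signature as an assumption:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <vector>

// Assumed shape of the RemoveId helper used above (not part of this
// diff): erase every occurrence of one special-token id in place.
void RemoveId(std::vector<int32_t>& ids, int32_t id) {
    ids.erase(std::remove(ids.begin(), ids.end(), id), ids.end());
}

int main() {
    // kBos, kBeginInst, <content>, kEndInst, kEos
    std::vector<int32_t> ids = {1, 3, 100, 101, 4, 2};
    RemoveId(ids, 3);  // MistralTemplate::kBeginInst
    RemoveId(ids, 4);  // MistralTemplate::kEndInst
    for (int32_t id : ids) std::cout << id << ' ';  // prints: 1 100 101 2
}
```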