
Commit

fix: template issue for tokenizer v3
sangjanai committed Jul 3, 2024
1 parent c7c8516 commit 9303456
Showing 2 changed files with 41 additions and 67 deletions.
107 changes: 41 additions & 66 deletions cpp/tensorrt_llm/cortex.tensorrt-llm/src/tensorrt-llm_engine.cc
@@ -39,10 +39,13 @@ namespace {
// '[', 'INST', ']', '[INST]', '[', '/', 'INST', ']', '[/INST]', '</s>'
const std::vector<int32_t> kMistral_V0_3_StopWords
= {29560, 17057, 29561, 3, 29560, 29516, 17057, 29561, 4, 2, 3, 4, 8, 9, 10, -1, -1, -1, -1, -1};
const std::string kMistralUserPrompt = "[INST] ";
const std::string kMistralAiPrompt = "[/INST] ";
const std::string kMistralSystemPrompt = "<s>";
const std::unordered_map<std::string, int> kMistralTemplate = {{"[INST]", 3} , {"[/INST]", 4}};

enum class MistralTemplate: int32_t {
kBos = 1,
kEos = 2,
kBeginInst = 3,
kEndInst = 4
};

// TODO(sang) This is fragile, just a temporary solution. Maybe can use a config file or model architect, etc...
bool IsOpenhermes(const std::string& s) {
@@ -51,27 +54,6 @@ namespace {
}
return true;
}
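
A hedged reconstruction of the collapsed IsOpenhermes body (the viewer hides the middle of the function): given the tail shown above and the TODO, it is most likely a plain substring test on the model path.

// Editorial sketch, not expanded from the diff: assumed body of IsOpenhermes.
bool IsOpenhermes(const std::string& s) {
  // Fragile by design (see the TODO): any path containing "mistral" selects
  // the Mistral template; everything else falls back to OpenHermes.
  if (s.find("mistral") != std::string::npos) {
    return false;
  }
  return true;
}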

std::string GetUserPrompt(bool is_openhermes) {
if(is_openhermes) {
return kOhUserPrompt;
}
return kMistralUserPrompt;
}

std::string GetAiPrompt(bool is_openhermes) {
if(is_openhermes) {
return kOhAiPrompt;
}
return kMistralAiPrompt;
}

std::string GetSystemPrompt(bool is_openhermes) {
if(is_openhermes) {
return kOhSystemPrompt;
}
return kMistralSystemPrompt;
}
}
TensorrtllmEngine::~TensorrtllmEngine() {}

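The hunks above replace the string prompt constants and the {"[INST]", 3} / {"[/INST]", 4} map with the MistralTemplate enum of special-token ids, and drop the Get*Prompt helpers. The flat kMistral_V0_3_StopWords list follows TensorRT-LLM's stop-words layout: every stop sequence's token ids concatenated, then each sequence's cumulative end offset, padded with -1. A minimal sketch of that layout (editorial; BuildStopWordsList is a hypothetical helper, not part of this commit):

#include <cstdint>
#include <vector>

std::vector<int32_t> BuildStopWordsList(
    const std::vector<std::vector<int32_t>>& stop_sequences) {
  std::vector<int32_t> ids;
  std::vector<int32_t> offsets;
  for (const auto& seq : stop_sequences) {
    ids.insert(ids.end(), seq.begin(), seq.end());
    offsets.push_back(static_cast<int32_t>(ids.size()));  // cumulative end
  }
  while (offsets.size() < ids.size()) offsets.push_back(-1);  // pad
  ids.insert(ids.end(), offsets.begin(), offsets.end());
  return ids;
}

// BuildStopWordsList({{29560, 17057, 29561}, {3}, {29560, 29516, 17057, 29561},
//                     {4}, {2}}) reproduces kMistral_V0_3_StopWords above.
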
@@ -84,56 +66,22 @@ bool HandleMatch(std::string const& rew_text,
std::function<void(Json::Value&&, Json::Value&&)> cb,
bool is_openhermes) {
if (infer_state->IsComplete(is_openhermes)) {
infer_state->rewind_strs.clear();
return false;
}
if (infer_state->stop_word_match_len == 0) {
if ((is_openhermes && rew_text.find('<') != std::string::npos) ||
(!is_openhermes && rew_text.find('[') != std::string::npos)) {
infer_state->stop_word_match_len++; // Move to next state
infer_state->rewind_strs.push_back(rew_text);
return true;
}
} else if (rew_text == infer_state->GetSequence(is_openhermes, infer_state->stop_word_match_len)) {
infer_state->stop_word_match_len++; // Move to next state
infer_state->rewind_strs.push_back(rew_text);
return true;
} else if (infer_state->stop_word_match_len > 0 && rew_text == infer_state->GetSequence(is_openhermes, 0u)) {
infer_state->stop_word_match_len = 1; // Restart from first match if sequence breaks but matches start
// response cache data
for(auto const& s: infer_state->rewind_strs) {
// std::cout << s;
const std::string text_to_stream
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
Json::Value resp_data;
resp_data["data"] = text_to_stream;
Json::Value status;
status["is_done"] = false;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = k200OK;
cb(std::move(status), std::move(resp_data));
}
infer_state->rewind_strs.clear();
infer_state->rewind_strs.push_back(rew_text);
return true;
} else {
infer_state->Reset();
// response cache data
for(auto const& s: infer_state->rewind_strs) {
// std::cout << s;
const std::string text_to_stream
= "data: " + tensorrtllm_utils::CreateReturnJson(tensorrtllm_utils::GenerateRandomString(20), "_", s) + "\n\n";
Json::Value resp_data;
resp_data["data"] = text_to_stream;
Json::Value status;
status["is_done"] = false;
status["has_error"] = false;
status["is_stream"] = true;
status["status_code"] = k200OK;
cb(std::move(status), std::move(resp_data));
}
infer_state->rewind_strs.clear();
return false; // Reset to start if sequence breaks
}
return false;
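
HandleMatch is an incremental matcher over the tokenizer's piece-by-piece spelling of the stop word ({"<", "|", "im", ...} for OpenHermes, {"[", "INST", "]"} for Mistral); the deleted lines are the rewind_strs buffer that used to re-emit withheld pieces once a partial match fell through. Reduced to a standalone function, the surviving state machine looks roughly like this (editorial sketch; MatchStopWord is hypothetical and simplified — the real first state checks with find() rather than equality):

#include <string>
#include <vector>

// Returns true while `piece` must be withheld from the client because it may
// still be the start or a continuation of the stop word.
bool MatchStopWord(const std::string& piece,
                   const std::vector<std::string>& sequence,
                   size_t& match_len) {  // caller-held state, starts at 0
  if (match_len < sequence.size() && piece == sequence[match_len]) {
    ++match_len;  // piece extends the candidate stop word
    return true;
  }
  // Mismatch: this piece either restarts a candidate match or resets it.
  match_len = (piece == sequence[0]) ? 1 : 0;
  return match_len > 0;
}
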
@@ -207,9 +155,8 @@ void InferenceThread(
RemoveId(output_idsHostDecode, v);
}
} else {
for(auto const& [_, v]: kMistralTemplate) {
RemoveId(output_idsHostDecode, v);
}
RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kBeginInst));
RemoveId(output_idsHostDecode, static_cast<int32_t>(MistralTemplate::kEndInst));
}
std::string text = self->cortex_tokenizer->Decode(output_idsHostDecode);

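In InferenceThread the cleanup above now strips the two Mistral template ids explicitly instead of iterating the removed kMistralTemplate map. RemoveId itself is defined elsewhere in this file; a plausible implementation (editorial sketch) is the erase-remove idiom:

#include <algorithm>
#include <cstdint>
#include <vector>

// Drop every occurrence of `id` so special tokens never reach Decode().
void RemoveId(std::vector<int32_t>& ids, int32_t id) {
  ids.erase(std::remove(ids.begin(), ids.end(), id), ids.end());
}
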
@@ -287,19 +234,37 @@ void TensorrtllmEngine::HandleChatCompletion(std::shared_ptr<Json::Value> json_b
data["presence_penalty"] = request.presence_penalty;
Json::Value const& messages = request.messages;

// tokens for Mistral v0.3
// TODO(sang): too much hard code here, need to refactor it soon
std::vector<int32_t> tokens = {static_cast<int32_t>(MistralTemplate::kBos)};

// Format the input from user
int msg_count = 0;
for (auto const& message : messages) {
std::string input_role = message["role"].asString();
std::string role;
if (input_role == "user") {
role = user_prompt_;
std::string content = message["content"].asString();
formatted_input += role + content;
if(!is_openhermes_) {
auto new_tokens = cortex_tokenizer->Encode(content);
new_tokens.insert(new_tokens.begin(), static_cast<int32_t>(MistralTemplate::kBeginInst));
new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEndInst));
tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
}
}
else if (input_role == "assistant") {
role = ai_prompt_;
std::string content = message["content"].asString();
formatted_input += role + content;
if(!is_openhermes_) {
auto new_tokens = cortex_tokenizer->Encode(content);
if(msg_count == messages.size() - 1) {
new_tokens.push_back(static_cast<int32_t>(MistralTemplate::kEos));
}
tokens.insert(tokens.end(), new_tokens.begin(), new_tokens.end());
}
}
else if (input_role == "system") {
role = system_prompt_;
@@ -311,14 +276,21 @@
std::string content = message["content"].asString();
formatted_input += role + content;
}
msg_count++;
}
formatted_input += ai_prompt_;
// LOG_INFO << formatted_input;
// Format the input from user

std::shared_ptr<InferenceState> infer_state = std::make_shared<InferenceState>();

std::vector<int32_t> input_ids_host = cortex_tokenizer->Encode(formatted_input);
std::vector<int32_t> input_ids_host;
if(is_openhermes_) {
input_ids_host = cortex_tokenizer->Encode(formatted_input);
} else {
input_ids_host = tokens;
}

int const input_len = input_ids_host.size();
int const outputLen = request.max_tokens - input_len;

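The net effect of the loop above: for Mistral v0.3 the prompt is now assembled directly in token-id space, presumably because encoding "[INST]" as plain text no longer yields the dedicated control ids under tokenizer v3 (hence the commit title), while OpenHermes keeps the old string path through formatted_input. A worked example (editorial; Encode output shown symbolically):

// messages = [ user:"Hello", assistant:"Hi", user:"Bye" ] yields
// tokens = { kBos,
//            kBeginInst, Encode("Hello")..., kEndInst,   // user turn
//            Encode("Hi")...,                            // assistant turn;
//                                                        // kEos only if last
//            kBeginInst, Encode("Bye")..., kEndInst };   // final user turn
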
@@ -397,9 +369,12 @@ void TensorrtllmEngine::LoadModel(std::shared_ptr<Json::Value> json_body, std::f
is_openhermes_ = IsOpenhermes(request.model_path);

int ctx_len = request.ctx_len;
user_prompt_ = request.user_prompt.empty() ? GetUserPrompt(is_openhermes_) : request.user_prompt;
ai_prompt_ = request.ai_prompt.empty() ? GetAiPrompt(is_openhermes_) : request.ai_prompt;
system_prompt_ = request.system_prompt.empty() ? GetSystemPrompt(is_openhermes_) : request.system_prompt;
// We only support 2 models for now, it is ugly but it works :(
if(is_openhermes_) {
user_prompt_ = request.user_prompt.empty() ? kOhUserPrompt : request.user_prompt;
ai_prompt_ = request.ai_prompt.empty() ? kOhAiPrompt : request.ai_prompt;
system_prompt_ = request.system_prompt.empty() ? kOhSystemPrompt : request.system_prompt;
}
model_id_ = GetModelId(*json_body);

logger_ = std::make_shared<TllmLogger>();
1 change: 0 additions & 1 deletion (header defining struct InferenceState)
@@ -75,7 +75,6 @@ struct InferenceState {
std::vector<std::string> sequence_openhermes = {"<", "|", "im", "_", "end", "|", ">"};
std::vector<std::string> sequence_mistral = {"[", "INST", "]"};
int token_gen_count = 0;
std::vector<std::string> rewind_strs;

void Reset() {
stop_word_match_len = 0;
