feat: support stream_option for OpenAI API compatible (#269)
Co-authored-by: vansangpfiev <[email protected]>
vansangpfiev and sangjanai authored Oct 29, 2024
1 parent 185a7cf commit a990689
Showing 2 changed files with 56 additions and 11 deletions.
7 changes: 7 additions & 0 deletions src/chat_completion_request.h
@@ -57,6 +57,7 @@ struct ChatCompletionRequest {
bool ignore_eos = false;
int n_probs = 0;
int min_keep = 0;
bool include_usage = false;
std::string grammar;
Json::Value logit_bias = Json::Value(Json::arrayValue);

@@ -80,6 +81,12 @@ inline ChatCompletionRequest fromJson(std::shared_ptr<Json::Value> jsonBody) {
common_sampler_params default_params;
if (jsonBody) {
completion.stream = (*jsonBody).get("stream", false).asBool();
if(completion.stream) {
auto& stream_options = (*jsonBody)["stream_options"];
if(!stream_options.isNull()) {
completion.include_usage = stream_options.get("include_usage", false).asBool();
}
}
completion.max_tokens = (*jsonBody).get("max_tokens", 500).asInt();
completion.top_p = (*jsonBody).get("top_p", 0.95).asFloat();
completion.temperature = (*jsonBody).get("temperature", 0.8).asFloat();
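
The new field only takes effect for streaming requests: fromJson() reads stream_options.include_usage only when stream is true. Below is a minimal standalone sketch of the same parsing logic with jsoncpp; the payload, main(), and error handling are illustrative, not part of the commit.

// Sketch: parse stream_options.include_usage from a request body with jsoncpp,
// mirroring the fromJson() change above. Payload values are hypothetical.
#include <iostream>
#include <sstream>
#include <string>
#include "json/json.h"

int main() {
  const std::string payload = R"({
    "stream": true,
    "stream_options": { "include_usage": true },
    "max_tokens": 128
  })";

  Json::CharReaderBuilder reader;
  Json::Value body;
  std::string errs;
  std::istringstream iss(payload);
  if (!Json::parseFromStream(reader, iss, &body, &errs)) {
    std::cerr << "parse error: " << errs << std::endl;
    return 1;
  }

  bool stream = body.get("stream", false).asBool();
  bool include_usage = false;
  if (stream) {
    const auto& stream_options = body["stream_options"];
    if (!stream_options.isNull()) {
      include_usage = stream_options.get("include_usage", false).asBool();
    }
  }
  std::cout << "include_usage = " << std::boolalpha << include_usage << std::endl;
  return 0;
}
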
60 changes: 49 additions & 11 deletions src/llama_engine.cc
@@ -1,6 +1,7 @@
#include "llama_engine.h"

#include <chrono>
#include <optional>
#include "json/writer.h"
#include "llama_utils.h"
#include "trantor/utils/Logger.h"
@@ -32,6 +33,11 @@ struct InferenceState {
InferenceState(LlamaServerContext& l) : llama(l) {}
};

struct Usage {
int prompt_tokens = 0;
int completion_tokens = 0;
};

/**
* This function creates the smart pointer to InferenceState, so the
* InferenceState keeps persisting even though the lambda used in streaming might go
@@ -95,7 +101,8 @@ Json::Value CreateFullReturnJson(const std::string& id,

std::string CreateReturnJson(const std::string& id, const std::string& model,
const std::string& content,
Json::Value finish_reason = Json::Value()) {
Json::Value finish_reason, bool include_usage,
std::optional<Usage> usage = std::nullopt) {
Json::Value root;

root["id"] = id;
@@ -104,16 +111,34 @@ std::string CreateReturnJson(const std::string& id, const std::string& model,
root["object"] = "chat.completion.chunk";

Json::Value choicesArray(Json::arrayValue);
Json::Value choice;
// When usage is attached, the choices field will always be an empty array
if (!usage) {
Json::Value choice;

choice["index"] = 0;
Json::Value delta;
delta["content"] = content;
choice["delta"] = delta;
choice["finish_reason"] = finish_reason;
choice["index"] = 0;
Json::Value delta;
delta["content"] = content;
choice["delta"] = delta;
choice["finish_reason"] = finish_reason;

choicesArray.append(choice);
choicesArray.append(choice);
}
root["choices"] = choicesArray;
if (include_usage) {
if (usage) {
Json::Value usage_json;
Json::Value details;
details["reasoning_tokens"] = 0;
usage_json["prompt_tokens"] = (*usage).prompt_tokens;
usage_json["completion_tokens"] = (*usage).completion_tokens;
usage_json["total_tokens"] =
(*usage).prompt_tokens + (*usage).completion_tokens;
usage_json["completion_tokens_details"] = details;
root["usage"] = usage_json;
} else {
root["usage"] = Json::Value();
}
}

Json::StreamWriterBuilder writer;
writer["indentation"] = ""; // This sets the indentation to an empty string,
@@ -400,6 +425,8 @@ void LlamaEngine::SetFileLogger(int max_log_lines,
}
},
nullptr);
freopen(log_path.c_str(), "w", stderr);
freopen(log_path.c_str(), "w", stdout);
}

bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> json_body) {
@@ -755,6 +782,7 @@ void LlamaEngine::HandleInferenceImpl(
data["stop"] = stopWords;

bool is_streamed = data["stream"];
bool include_usage = completion.include_usage;
// Enable full message debugging
#ifdef DEBUG
LOG_INFO << "Request " << request_id << ": " << "Current completion text";
@@ -768,7 +796,7 @@

// Queued task
si.q->runTaskInQueue([cb = std::move(callback), state, data, request_id,
n_probs]() {
n_probs, include_usage]() {
state->task_id = state->llama.RequestCompletion(data, false, false, -1);
while (state->llama.model_loaded_external) {
TaskResult result = state->llama.NextResult(state->task_id);
@@ -787,7 +815,7 @@
const std::string str =
"data: " +
CreateReturnJson(llama_utils::generate_random_string(20), "_",
to_send) +
to_send, "", include_usage, std::nullopt) +
"\n\n";
Json::Value respData;
respData["data"] = str;
@@ -801,11 +829,21 @@
if (result.stop) {
LOG_INFO << "Request " << request_id << ": " << "End of result";
state->llama.RequestCancel(state->task_id);
// include_usage
// If set, an additional chunk will be streamed before the data: [DONE] message.
// The usage field on this chunk shows the token usage statistics for the entire request,
// and the choices field will always be an empty array.
// All other chunks will also include a usage field, but with a null value.
Json::Value respData;
std::optional<Usage> u;
if (include_usage) {
u = Usage{result.result_json["tokens_evaluated"],
result.result_json["tokens_predicted"]};
}
const std::string str =
"data: " +
CreateReturnJson(llama_utils::generate_random_string(20), "_",
"", "stop") +
"", "stop", include_usage, u) +
"\n\n" + "data: [DONE]" + "\n\n";
respData["data"] = str;
Json::Value status;
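
When include_usage is set, one extra chunk is streamed before data: [DONE]: its choices array is empty and its usage field carries the token counts taken from result_json["tokens_evaluated"] and result_json["tokens_predicted"]. Below is a standalone sketch of how such a final chunk is serialized with jsoncpp; the id, timestamp field, and token counts are hypothetical, and CreateReturnJson itself is internal to llama_engine.cc.

// Sketch: build and print a final usage chunk shaped like the one emitted by
// CreateReturnJson when include_usage is true. Values are hypothetical.
#include <ctime>
#include <iostream>
#include "json/json.h"

int main() {
  Json::Value root;
  root["id"] = "sample_id";                                 // placeholder id
  root["model"] = "_";
  root["created"] = static_cast<int>(std::time(nullptr));   // timestamp assumed from the OpenAI chunk format
  root["object"] = "chat.completion.chunk";
  root["choices"] = Json::Value(Json::arrayValue);          // empty when usage is attached

  Json::Value usage;
  Json::Value details;
  details["reasoning_tokens"] = 0;
  usage["prompt_tokens"] = 10;        // hypothetical tokens_evaluated
  usage["completion_tokens"] = 25;    // hypothetical tokens_predicted
  usage["total_tokens"] = 35;
  usage["completion_tokens_details"] = details;
  root["usage"] = usage;

  Json::StreamWriterBuilder writer;
  writer["indentation"] = "";  // single-line JSON, as in CreateReturnJson
  std::cout << "data: " << Json::writeString(writer, root) << "\n\n"
            << "data: [DONE]" << "\n\n";
  return 0;
}
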
