diff --git a/src/chat_completion_request.h b/src/chat_completion_request.h
index 9ffc7d6..3b19115 100644
--- a/src/chat_completion_request.h
+++ b/src/chat_completion_request.h
@@ -57,6 +57,7 @@ struct ChatCompletionRequest {
   bool ignore_eos = false;
   int n_probs = 0;
   int min_keep = 0;
+  bool include_usage = false;
   std::string grammar;
 
   Json::Value logit_bias = Json::Value(Json::arrayValue);
@@ -80,6 +81,12 @@ inline ChatCompletionRequest fromJson(std::shared_ptr<Json::Value> jsonBody) {
   common_sampler_params default_params;
   if (jsonBody) {
     completion.stream = (*jsonBody).get("stream", false).asBool();
+    if (completion.stream) {
+      auto& stream_options = (*jsonBody)["stream_options"];
+      if (!stream_options.isNull()) {
+        completion.include_usage = stream_options.get("include_usage", false).asBool();
+      }
+    }
     completion.max_tokens = (*jsonBody).get("max_tokens", 500).asInt();
     completion.top_p = (*jsonBody).get("top_p", 0.95).asFloat();
     completion.temperature = (*jsonBody).get("temperature", 0.8).asFloat();
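The parser only consults `stream_options` when `stream` is true, mirroring the OpenAI API, where `stream_options` is only meaningful on streaming requests, and `include_usage` defaults to false when absent. Below is a minimal, self-contained sketch of the expected request shape and the opt-in logic, assuming only jsoncpp; the field values are illustrative, not taken from a real request:

```cpp
#include <iostream>
#include <json/json.h>

int main() {
  // Request body a client would send to opt in to usage reporting
  // (field names follow the OpenAI chat completions API).
  Json::Value body;
  body["stream"] = true;
  body["stream_options"]["include_usage"] = true;

  // Mirrors the fromJson() logic above: stream_options is ignored
  // entirely unless streaming was requested.
  bool include_usage = false;
  if (body.get("stream", false).asBool()) {
    const Json::Value& stream_options = body["stream_options"];
    if (!stream_options.isNull()) {
      include_usage = stream_options.get("include_usage", false).asBool();
    }
  }
  std::cout << std::boolalpha << include_usage << "\n";  // prints: true
}
```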
diff --git a/src/llama_engine.cc b/src/llama_engine.cc
index 197f994..5a4700c 100644
--- a/src/llama_engine.cc
+++ b/src/llama_engine.cc
@@ -1,6 +1,7 @@
 #include "llama_engine.h"
 
 #include <chrono>
+#include <optional>
 #include "json/writer.h"
 #include "llama_utils.h"
 #include "trantor/utils/Logger.h"
@@ -32,6 +33,11 @@ struct InferenceState {
   InferenceState(LlamaServerContext& l) : llama(l) {}
 };
 
+struct Usage {
+  int prompt_tokens = 0;
+  int completion_tokens = 0;
+};
+
 /**
  * This function is to create the smart pointer to InferenceState, hence the
  * InferenceState will be persisting even tho the lambda in streaming might go
@@ -95,7 +101,8 @@ Json::Value CreateFullReturnJson(const std::string& id,
 
 std::string CreateReturnJson(const std::string& id, const std::string& model,
                              const std::string& content,
-                             Json::Value finish_reason = Json::Value()) {
+                             Json::Value finish_reason, bool include_usage,
+                             std::optional<Usage> usage = std::nullopt) {
   Json::Value root;
 
   root["id"] = id;
@@ -104,16 +111,34 @@ std::string CreateReturnJson(const std::string& id, const std::string& model,
   root["object"] = "chat.completion.chunk";
 
   Json::Value choicesArray(Json::arrayValue);
-  Json::Value choice;
+  // If usage is set, the choices field will always be an empty array
+  if (!usage) {
+    Json::Value choice;
 
-  choice["index"] = 0;
-  Json::Value delta;
-  delta["content"] = content;
-  choice["delta"] = delta;
-  choice["finish_reason"] = finish_reason;
+    choice["index"] = 0;
+    Json::Value delta;
+    delta["content"] = content;
+    choice["delta"] = delta;
+    choice["finish_reason"] = finish_reason;
 
-  choicesArray.append(choice);
+    choicesArray.append(choice);
+  }
   root["choices"] = choicesArray;
+  if (include_usage) {
+    if (usage) {
+      Json::Value usage_json;
+      Json::Value details;
+      details["reasoning_tokens"] = 0;
+      usage_json["prompt_tokens"] = (*usage).prompt_tokens;
+      usage_json["completion_tokens"] = (*usage).completion_tokens;
+      usage_json["total_tokens"] =
+          (*usage).prompt_tokens + (*usage).completion_tokens;
+      usage_json["completion_tokens_details"] = details;
+      root["usage"] = usage_json;
+    } else {
+      root["usage"] = Json::Value();
+    }
+  }
 
   Json::StreamWriterBuilder writer;
   writer["indentation"] = "";  // This sets the indentation to an empty string,
@@ -400,6 +425,8 @@ void LlamaEngine::SetFileLogger(int max_log_lines,
         }
       },
       nullptr);
+  freopen(log_path.c_str(), "w", stderr);
+  freopen(log_path.c_str(), "w", stdout);
 }
 
 bool LlamaEngine::LoadModelImpl(std::shared_ptr<Json::Value> json_body) {
@@ -755,6 +782,7 @@ void LlamaEngine::HandleInferenceImpl(
   data["stop"] = stopWords;
   bool is_streamed = data["stream"];
+  bool include_usage = completion.include_usage;
   // Enable full message debugging
 #ifdef DEBUG
   LOG_INFO << "Request " << request_id << ": "
           << "Current completion text";
@@ -768,7 +796,7 @@ void LlamaEngine::HandleInferenceImpl(
 
     // Queued task
     si.q->runTaskInQueue([cb = std::move(callback), state, data, request_id,
-                          n_probs]() {
+                          n_probs, include_usage]() {
      state->task_id = state->llama.RequestCompletion(data, false, false, -1);
      while (state->llama.model_loaded_external) {
        TaskResult result = state->llama.NextResult(state->task_id);
@@ -787,7 +815,7 @@ void LlamaEngine::HandleInferenceImpl(
           const std::string str =
               "data: " +
               CreateReturnJson(llama_utils::generate_random_string(20), "_",
-                               to_send) +
+                               to_send, Json::Value(), include_usage, std::nullopt) +
               "\n\n";
           Json::Value respData;
           respData["data"] = str;
@@ -801,11 +829,21 @@ void LlamaEngine::HandleInferenceImpl(
           if (result.stop) {
             LOG_INFO << "Request " << request_id << ": " << "End of result";
             state->llama.RequestCancel(state->task_id);
+            // include_usage
+            // If set, an additional chunk will be streamed before the data: [DONE] message.
+            // The usage field on this chunk shows the token usage statistics for the entire request,
+            // and the choices field will always be an empty array.
+            // All other chunks will also include a usage field, but with a null value.
             Json::Value respData;
+            std::optional<Usage> u;
+            if (include_usage) {
+              u = Usage{result.result_json["tokens_evaluated"].asInt(),
+                        result.result_json["tokens_predicted"].asInt()};
+            }
             const std::string str =
                 "data: " +
                 CreateReturnJson(llama_utils::generate_random_string(20), "_",
-                                 "", "stop") +
+                                 "", "stop", include_usage, u) +
                 "\n\n" + "data: [DONE]" + "\n\n";
             respData["data"] = str;
             Json::Value status;
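With `include_usage` set, the engine emits one extra chunk before `data: [DONE]` whose `choices` array is empty and whose `usage` object carries the totals for the whole request; every other chunk carries `"usage": null`. Below is a minimal, self-contained sketch of how that final chunk is shaped, assuming only jsoncpp; the id and token counts are illustrative stand-ins for `generate_random_string(20)` and the engine's `tokens_evaluated` / `tokens_predicted` values:

```cpp
#include <iostream>
#include <json/json.h>

int main() {
  // Shape of the extra chunk emitted before "data: [DONE]" when
  // stream_options.include_usage was requested.
  Json::Value root;
  root["id"] = "aGVsbG9oZWxsb2hlbGxv";  // engine uses a random 20-char id
  root["model"] = "_";
  root["object"] = "chat.completion.chunk";
  root["choices"] = Json::Value(Json::arrayValue);  // always empty on this chunk

  Json::Value usage;
  usage["prompt_tokens"] = 10;       // tokens_evaluated
  usage["completion_tokens"] = 32;   // tokens_predicted
  usage["total_tokens"] = 42;        // sum of the two
  usage["completion_tokens_details"]["reasoning_tokens"] = 0;
  root["usage"] = usage;

  Json::StreamWriterBuilder writer;
  writer["indentation"] = "";  // single-line JSON, suitable for an SSE data: line
  std::cout << "data: " << Json::writeString(writer, root) << "\n\n";
}
```

Passing `std::nullopt` for `usage` on the intermediate chunks, with `include_usage` still true, is what produces the `"usage": null` field on every chunk except this final one.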