From e843d7df0e8b177ab122a9f7bfa7196274ccd204 Mon Sep 17 00:00:00 2001
From: Ettore Di Giacinto
Date: Mon, 15 Apr 2024 19:47:11 +0200
Subject: [PATCH] feat(grpc): return consumed token count and update response accordingly (#2035)

Fixes: #1920
---
 backend/backend.proto             | 2 ++
 backend/cpp/llama/grpc-server.cpp | 8 ++++++++
 core/backend/llm.go               | 6 ++++++
 core/services/openai.go           | 8 ++++----
 4 files changed, 20 insertions(+), 4 deletions(-)

diff --git a/backend/backend.proto b/backend/backend.proto
index 56d919efd3b0..62e1a1a64448 100644
--- a/backend/backend.proto
+++ b/backend/backend.proto
@@ -114,6 +114,8 @@ message PredictOptions {
 // The response message containing the result
 message Reply {
   bytes message = 1;
+  int32 tokens = 2;
+  int32 prompt_tokens = 3;
 }
 
 message ModelOptions {
diff --git a/backend/cpp/llama/grpc-server.cpp b/backend/cpp/llama/grpc-server.cpp
index a2e39a9c5f65..6fb086585f4e 100644
--- a/backend/cpp/llama/grpc-server.cpp
+++ b/backend/cpp/llama/grpc-server.cpp
@@ -2332,6 +2332,10 @@ class BackendServiceImpl final : public backend::Backend::Service {
             std::string completion_text = result.result_json.value("content", "");
 
             reply.set_message(completion_text);
+            int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
+            reply.set_tokens(tokens_predicted);
+            int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
+            reply.set_prompt_tokens(tokens_evaluated);
 
             // Send the reply
             writer->Write(reply);
@@ -2357,6 +2361,10 @@ class BackendServiceImpl final : public backend::Backend::Service {
         task_result result = llama.queue_results.recv(task_id);
         if (!result.error && result.stop) {
             completion_text = result.result_json.value("content", "");
+            int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
+            int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
+            reply->set_prompt_tokens(tokens_evaluated);
+            reply->set_tokens(tokens_predicted);
             reply->set_message(completion_text);
         }
         else
diff --git a/core/backend/llm.go b/core/backend/llm.go
index 1878e87af6b2..75766d78d437 100644
--- a/core/backend/llm.go
+++ b/core/backend/llm.go
@@ -189,6 +189,12 @@ func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest,
 	} else {
 		go func() {
 			reply, err := inferenceModel.Predict(ctx, grpcPredOpts)
+			if tokenUsage.Prompt == 0 {
+				tokenUsage.Prompt = int(reply.PromptTokens)
+			}
+			if tokenUsage.Completion == 0 {
+				tokenUsage.Completion = int(reply.Tokens)
+			}
 			if err != nil {
 				rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err}
 				close(rawResultChannel)
diff --git a/core/services/openai.go b/core/services/openai.go
index 0f61d6f42da9..3fa041f5ee96 100644
--- a/core/services/openai.go
+++ b/core/services/openai.go
@@ -160,7 +160,7 @@ func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest
 
 	bc, request, err := oais.getConfig(request)
 	if err != nil {
-		log.Error().Msgf("[oais::GenerateTextFromRequest] error getting configuration: %q", err)
+		log.Error().Err(err).Msgf("[oais::GenerateTextFromRequest] error getting configuration")
 		return
 	}
 
@@ -259,7 +259,7 @@ func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest
 	// If any of the setup goroutines experienced an error, quit early here.
 	if setupError != nil {
 		go func() {
-			log.Error().Msgf("[OAIS GenerateTextFromRequest] caught an error during setup: %q", setupError)
+			log.Error().Err(setupError).Msgf("[OAIS GenerateTextFromRequest] caught an error during setup")
 			rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: setupError}
 			close(rawFinalResultChannel)
 		}()
@@ -603,7 +603,7 @@ func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *sche
 				Usage: schema.OpenAIUsage{
 					PromptTokens:     rawResult.Value.Usage.Prompt,
 					CompletionTokens: rawResult.Value.Usage.Completion,
-					TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
+					TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion,
 				},
 			}
 
@@ -644,7 +644,7 @@ func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *sche
 				Usage: schema.OpenAIUsage{
 					PromptTokens:     rawResult.Value.Usage.Prompt,
 					CompletionTokens: rawResult.Value.Usage.Completion,
-					TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
+					TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion,
 				},
 			}
 
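
For context, a minimal sketch of how the two new counters flow end-to-end: the llama.cpp backend reports tokens_predicted and tokens_evaluated, the gRPC Reply carries them as tokens and prompt_tokens, and the OpenAI layer sums them into TotalTokens. This is not LocalAI code: reply, tokenUsage and openAIUsage below are simplified stand-ins for the generated backend.Reply, the usage accumulator in core/backend/llm.go and schema.OpenAIUsage.

package main

import "fmt"

// reply is a stand-in for the generated backend.Reply message:
// Tokens carries tokens_predicted, PromptTokens carries tokens_evaluated.
type reply struct {
	Message      []byte
	Tokens       int32
	PromptTokens int32
}

// tokenUsage is a stand-in for the accumulator used in core/backend/llm.go.
type tokenUsage struct {
	Prompt     int
	Completion int
}

// openAIUsage is a stand-in for schema.OpenAIUsage.
type openAIUsage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

// fillUsage copies the backend-reported counts into the accumulator unless a
// count was already obtained elsewhere (the == 0 checks in the llm.go hunk).
func fillUsage(u *tokenUsage, r *reply) {
	if r == nil {
		return // nothing was reported, e.g. the Predict call failed
	}
	if u.Prompt == 0 {
		u.Prompt = int(r.PromptTokens)
	}
	if u.Completion == 0 {
		u.Completion = int(r.Tokens)
	}
}

// toOpenAIUsage builds the usage block; TotalTokens is prompt + completion,
// which is what the two openai.go hunks now compute.
func toOpenAIUsage(u tokenUsage) openAIUsage {
	return openAIUsage{
		PromptTokens:     u.Prompt,
		CompletionTokens: u.Completion,
		TotalTokens:      u.Prompt + u.Completion,
	}
}

func main() {
	r := &reply{Message: []byte("hello"), Tokens: 12, PromptTokens: 34}
	var u tokenUsage
	fillUsage(&u, r)
	fmt.Printf("%+v\n", toOpenAIUsage(u)) // {PromptTokens:34 CompletionTokens:12 TotalTokens:46}
}

With these example values, 34 prompt tokens and 12 completion tokens give TotalTokens = 46; before the openai.go change the prompt count was added to itself, which would have reported 68.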