Commit

feat(grpc): return consumed token count and update response accordingly (mudler#2035)

Fixes: mudler#1920
mudler authored Apr 15, 2024
1 parent de3a1a0 commit e843d7d
Showing 4 changed files with 20 additions and 4 deletions.
2 changes: 2 additions & 0 deletions backend/backend.proto
@@ -114,6 +114,8 @@ message PredictOptions {
 // The response message containing the result
 message Reply {
   bytes message = 1;
+  int32 tokens = 2;
+  int32 prompt_tokens = 3;
 }
 
 message ModelOptions {
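The two new counters ride along with every Reply. Assuming the usual protoc-gen-go generation (the pb import path below is a guess, not taken from the commit), they surface on the generated struct as Tokens and PromptTokens, which is how the Go code later in this commit reads them. A minimal sketch:

package main

import (
	"fmt"

	// Hypothetical import path for the Go package generated from backend.proto.
	pb "github.com/go-skynet/LocalAI/pkg/grpc/proto"
)

// usageFromReply reads the two new counters off a Reply.
func usageFromReply(r *pb.Reply) (prompt, completion, total int) {
	prompt = int(r.GetPromptTokens())
	completion = int(r.GetTokens())
	return prompt, completion, prompt + completion
}

func main() {
	r := &pb.Reply{Message: []byte("hello"), Tokens: 5, PromptTokens: 12}
	fmt.Println(usageFromReply(r)) // 12 5 17
}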
8 changes: 8 additions & 0 deletions backend/cpp/llama/grpc-server.cpp
@@ -2332,6 +2332,10 @@ class BackendServiceImpl final : public backend::Backend::Service {
         std::string completion_text = result.result_json.value("content", "");
 
         reply.set_message(completion_text);
+        int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
+        reply.set_tokens(tokens_predicted);
+        int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
+        reply.set_prompt_tokens(tokens_evaluated);
 
         // Send the reply
         writer->Write(reply);
@@ -2357,6 +2361,10 @@ class BackendServiceImpl final : public backend::Backend::Service {
         task_result result = llama.queue_results.recv(task_id);
         if (!result.error && result.stop) {
             completion_text = result.result_json.value("content", "");
+            int32_t tokens_predicted = result.result_json.value("tokens_predicted", 0);
+            int32_t tokens_evaluated = result.result_json.value("tokens_evaluated", 0);
+            reply->set_prompt_tokens(tokens_evaluated);
+            reply->set_tokens(tokens_predicted);
             reply->set_message(completion_text);
         }
         else
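Both handlers above copy content, tokens_predicted, and tokens_evaluated out of the llama.cpp result JSON into the Reply. The same mapping, expressed in Go purely as an illustration (the llamaResult struct, the reply stand-in, and replyFromResult are hypothetical; only the JSON keys come from the code above):

package main

import (
	"encoding/json"
	"fmt"
)

// llamaResult mirrors the fields the C++ handlers read from the backend result JSON.
type llamaResult struct {
	Content         string `json:"content"`
	TokensPredicted int32  `json:"tokens_predicted"`
	TokensEvaluated int32  `json:"tokens_evaluated"`
}

// reply is a stand-in for the generated backend.Reply message.
type reply struct {
	Message      []byte
	Tokens       int32
	PromptTokens int32
}

func replyFromResult(raw []byte) (*reply, error) {
	var res llamaResult
	if err := json.Unmarshal(raw, &res); err != nil {
		return nil, err
	}
	return &reply{
		Message:      []byte(res.Content),
		Tokens:       res.TokensPredicted, // generated tokens
		PromptTokens: res.TokensEvaluated, // prompt (evaluated) tokens
	}, nil
}

func main() {
	r, err := replyFromResult([]byte(`{"content":"hi","tokens_predicted":5,"tokens_evaluated":12}`))
	if err != nil {
		panic(err)
	}
	fmt.Printf("message=%s tokens=%d prompt_tokens=%d\n", r.Message, r.Tokens, r.PromptTokens)
}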
6 changes: 6 additions & 0 deletions core/backend/llm.go
@@ -189,6 +189,12 @@ func (llmbs *LLMBackendService) Inference(ctx context.Context, req *LLMRequest,
 	} else {
 		go func() {
 			reply, err := inferenceModel.Predict(ctx, grpcPredOpts)
+			if tokenUsage.Prompt == 0 {
+				tokenUsage.Prompt = int(reply.PromptTokens)
+			}
+			if tokenUsage.Completion == 0 {
+				tokenUsage.Completion = int(reply.Tokens)
+			}
 			if err != nil {
 				rawResultChannel <- concurrency.ErrorOr[*LLMResponse]{Error: err}
 				close(rawResultChannel)
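One detail worth noting in the hunk above: reply.PromptTokens and reply.Tokens are read before the err check, so the new lines rely on Predict never returning a nil reply alongside an error. A defensive ordering, sketched with stand-in types (Reply and TokenUsage here are placeholders, not the types from llm.go):

package main

import "fmt"

// Stand-ins for the types used in llm.go; names are illustrative only.
type Reply struct{ Tokens, PromptTokens int32 }
type TokenUsage struct{ Prompt, Completion int }

// fillUsage copies the backend-reported counters, but only after the error
// check, so a nil reply from a failed Predict is never dereferenced.
func fillUsage(u *TokenUsage, reply *Reply, err error) error {
	if err != nil {
		return err
	}
	if u.Prompt == 0 {
		u.Prompt = int(reply.PromptTokens)
	}
	if u.Completion == 0 {
		u.Completion = int(reply.Tokens)
	}
	return nil
}

func main() {
	var usage TokenUsage
	if err := fillUsage(&usage, &Reply{Tokens: 5, PromptTokens: 12}, nil); err != nil {
		panic(err)
	}
	fmt.Printf("%+v\n", usage) // {Prompt:12 Completion:5}
}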
8 changes: 4 additions & 4 deletions core/services/openai.go
@@ -160,7 +160,7 @@ func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest
 
 	bc, request, err := oais.getConfig(request)
 	if err != nil {
-		log.Error().Msgf("[oais::GenerateTextFromRequest] error getting configuration: %q", err)
+		log.Error().Err(err).Msgf("[oais::GenerateTextFromRequest] error getting configuration")
 		return
 	}
 
@@ -259,7 +259,7 @@ func (oais *OpenAIService) GenerateTextFromRequest(request *schema.OpenAIRequest
 	// If any of the setup goroutines experienced an error, quit early here.
 	if setupError != nil {
 		go func() {
-			log.Error().Msgf("[OAIS GenerateTextFromRequest] caught an error during setup: %q", setupError)
+			log.Error().Err(setupError).Msgf("[OAIS GenerateTextFromRequest] caught an error during setup")
 			rawFinalResultChannel <- concurrency.ErrorOr[*schema.OpenAIResponse]{Error: setupError}
 			close(rawFinalResultChannel)
 		}()
@@ -603,7 +603,7 @@ func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *sche
 			Usage: schema.OpenAIUsage{
 				PromptTokens:     rawResult.Value.Usage.Prompt,
 				CompletionTokens: rawResult.Value.Usage.Completion,
-				TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
+				TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion,
 			},
 		}
 
@@ -644,7 +644,7 @@ func (oais *OpenAIService) GenerateFromMultipleMessagesChatRequest(request *sche
 			Usage: schema.OpenAIUsage{
 				PromptTokens:     rawResult.Value.Usage.Prompt,
 				CompletionTokens: rawResult.Value.Usage.Completion,
-				TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Prompt,
+				TotalTokens:      rawResult.Value.Usage.Prompt + rawResult.Value.Usage.Completion,
 			},
 		}
 
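The two openai.go hunks fix the same copy-paste slip: TotalTokens was adding the prompt count to itself. A minimal illustration of the corrected arithmetic (the struct below is a stand-in for schema.OpenAIUsage; the numbers are made up):

package main

import "fmt"

// Stand-in for schema.OpenAIUsage; field names follow the diff above.
type OpenAIUsage struct {
	PromptTokens     int
	CompletionTokens int
	TotalTokens      int
}

func main() {
	u := OpenAIUsage{PromptTokens: 12, CompletionTokens: 5}
	u.TotalTokens = u.PromptTokens + u.CompletionTokens // was Prompt + Prompt (24) before the fix
	fmt.Printf("%+v\n", u) // {PromptTokens:12 CompletionTokens:5 TotalTokens:17}
}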
