Commit

fix: finish_reason for non stream completion (#16)
* fix: finish_reason for non stream completion

* fix: disable llama log at start

---------

Co-authored-by: vansangpfiev <[email protected]>
vansangpfiev and sangjanai authored May 8, 2024
1 parent 3e7fd66 commit 7316198
Showing 2 changed files with 58 additions and 46 deletions.
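For context, the non-stream fix passes an explicit "stop" finish_reason into create_full_return_json, so the completion body lines up with the OpenAI chat completion object. Below is a rough sketch of the resulting payload shape, built with the jsoncpp Json::Value type the engine already uses; SketchNonStreamResponse is a hypothetical helper for illustration, not the engine's create_full_return_json (whose implementation is outside this diff), and the exact field layout may differ.

// Illustrative sketch only: the rough shape of a non-stream completion
// response once finish_reason is populated. Field names follow the OpenAI
// chat completion object; the engine's real payload comes from
// create_full_return_json, which this diff does not show.
#include <json/json.h>
#include <string>

Json::Value SketchNonStreamResponse(const std::string& id,
                                    const std::string& content,
                                    int prompt_tokens, int predicted_tokens) {
  Json::Value root;
  root["id"] = id;
  root["object"] = "chat.completion";

  Json::Value choice;
  choice["index"] = 0;
  choice["message"]["role"] = "assistant";
  choice["message"]["content"] = content;
  choice["finish_reason"] = "stop";  // the value this commit now sends

  root["choices"] = Json::Value(Json::arrayValue);
  root["choices"].append(choice);

  root["usage"]["prompt_tokens"] = prompt_tokens;
  root["usage"]["completion_tokens"] = predicted_tokens;
  root["usage"]["total_tokens"] = prompt_tokens + predicted_tokens;
  return root;
}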
100 changes: 56 additions & 44 deletions src/LlamaEngine.cc
@@ -112,6 +112,10 @@ std::string create_return_json(const std::string& id, const std::string& model,
 }
 } // namespace
 
+LlamaEngine::LlamaEngine() {
+log_disable();
+}
+
 LlamaEngine::~LlamaEngine() {
 StopBackgroundTask();
 }
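The new constructor turns off the llama.cpp backend's logging before any model work happens, which is what the second commit message line refers to. If the logs were wanted back while debugging, a small variant along these lines is conceivable; this is only a sketch, it assumes a matching log_enable() in the same llama.cpp logging header, and LLAMA_ENGINE_DEBUG is a made-up environment variable, not something the engine reads.

// Sketch, not the shipped behavior: keep backend logging off by default but
// allow opting back in through an environment variable while debugging.
// Assumes log_enable()/log_disable() from the llama.cpp logging header the
// engine already includes; LLAMA_ENGINE_DEBUG is a hypothetical knob.
#include <cstdlib>

LlamaEngine::LlamaEngine() {
  if (std::getenv("LLAMA_ENGINE_DEBUG") != nullptr) {
    log_enable();   // verbose backend logs for troubleshooting
  } else {
    log_disable();  // what the commit does unconditionally
  }
}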
@@ -538,7 +542,8 @@ void LlamaEngine::HandleInferenceImpl(
 cb(std::move(status), std::move(respData));
 
 if (result.stop) {
-LOG_INFO << "Request " << request_id << ": " << "End of result";
+LOG_INFO << "Request " << request_id << ": "
+<< "End of result";
 state->llama.request_cancel(state->task_id);
 Json::Value respData;
 const std::string str =
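The hunk above is the streaming branch: once result.stop arrives, the task is cancelled and a final chunk is assembled (the string literal is cut off by the diff view). For orientation, OpenAI-style streams usually end with a last server-sent event whose choice carries a finish_reason, followed by a [DONE] sentinel; the sketch below only illustrates that general shape and is not the exact string this function builds.

// Generic sketch of how an OpenAI-style stream terminates: an empty delta
// with finish_reason set, then the [DONE] sentinel. Illustrative only; the
// exact string assembled in HandleInferenceImpl is truncated in this view.
#include <json/json.h>
#include <string>

std::string SketchFinalStreamChunk(const std::string& id) {
  Json::Value chunk;
  chunk["id"] = id;
  chunk["object"] = "chat.completion.chunk";

  Json::Value choice;
  choice["index"] = 0;
  choice["delta"] = Json::Value(Json::objectValue);
  choice["finish_reason"] = "stop";
  chunk["choices"] = Json::Value(Json::arrayValue);
  chunk["choices"].append(choice);

  Json::StreamWriterBuilder writer;
  writer["indentation"] = "";  // compact single-line JSON for the data: field
  return "data: " + Json::writeString(writer, chunk) + "\n\n" +
         "data: [DONE]\n\n";
}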
@@ -574,56 +579,63 @@ void LlamaEngine::HandleInferenceImpl(
 LOG_INFO << "Request " << request_id << ": "
 << "Task completed, release it";
 // Request completed, release it
-if(!state->llama.model_loaded_external) {
+if (!state->llama.model_loaded_external) {
 LOG_WARN << "Model unloaded during inference";
 Json::Value respData;
-respData["data"] = std::string();
-Json::Value status;
-status["is_done"] = false;
-status["has_error"] = true;
-status["is_stream"] = true;
-status["status_code"] = k200OK;
-cb(std::move(status), std::move(respData));
+respData["data"] = std::string();
+Json::Value status;
+status["is_done"] = false;
+status["has_error"] = true;
+status["is_stream"] = true;
+status["status_code"] = k200OK;
+cb(std::move(status), std::move(respData));
 }
 LOG_INFO << "Request " << request_id << ": "
 << "Inference completed";
 });
 } else {
-queue_->runTaskInQueue(
-[this, request_id, cb = std::move(callback), d = std::move(data)]() {
-Json::Value respData;
-int task_id = llama_.request_completion(d, false, false, -1);
-LOG_INFO << "Request " << request_id << ": "
-<< "Non stream, waiting for respone";
-if (!json_value(d, "stream", false)) {
-bool has_error = false;
-std::string completion_text;
-task_result result = llama_.next_result(task_id);
-if (!result.error && result.stop) {
-int prompt_tokens = result.result_json["tokens_evaluated"];
-int predicted_tokens = result.result_json["tokens_predicted"];
-std::string to_send = result.result_json["content"];
-llama_utils::ltrim(to_send);
-respData = create_full_return_json(
-llama_utils::generate_random_string(20), "_", to_send, "_",
-prompt_tokens, predicted_tokens);
-} else {
-bool has_error = true;
-respData["message"] = "Internal error during inference";
-LOG_ERROR << "Request " << request_id << ": "
-<< "Error during inference";
-}
-Json::Value status;
-status["is_done"] = true;
-status["has_error"] = has_error;
-status["is_stream"] = false;
-status["status_code"] = k200OK;
-cb(std::move(status), std::move(respData));
-
-LOG_INFO << "Request " << request_id << ": "
-<< "Inference completed";
-}
-});
+queue_->runTaskInQueue([this, request_id, cb = std::move(callback),
+d = std::move(data)]() {
+Json::Value respData;
+int task_id = llama_.request_completion(d, false, false, -1);
+LOG_INFO << "Request " << request_id << ": "
+<< "Non stream, waiting for respone";
+if (!json_value(d, "stream", false)) {
+bool has_error = false;
+std::string completion_text;
+task_result result = llama_.next_result(task_id);
+if (!result.error && result.stop) {
+int prompt_tokens = result.result_json["tokens_evaluated"];
+int predicted_tokens = result.result_json["tokens_predicted"];
+std::string to_send = result.result_json["content"];
+llama_utils::ltrim(to_send);
+//https://platform.openai.com/docs/api-reference/chat/object
+// finish_reason string
+// The reason the model stopped generating tokens. This will be `stop`
+// if the model hit a natural stop point or a provided stop sequence,
+// `length` if the maximum number of tokens specified in the request was reached,
+// `content_filter` if content was omitted due to a flag from our content filters,
+// `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.
+respData = create_full_return_json(
+llama_utils::generate_random_string(20), "_", to_send, "_",
+prompt_tokens, predicted_tokens, "stop" /*finish_reason*/);
+} else {
+bool has_error = true;
+respData["message"] = "Internal error during inference";
+LOG_ERROR << "Request " << request_id << ": "
+<< "Error during inference";
+}
+Json::Value status;
+status["is_done"] = true;
+status["has_error"] = has_error;
+status["is_stream"] = false;
+status["status_code"] = k200OK;
+cb(std::move(status), std::move(respData));
+
+LOG_INFO << "Request " << request_id << ": "
+<< "Inference completed";
+}
+});
 }
 }
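The non-stream branch above reports back through the callback as a (status, body) pair: is_done, has_error, is_stream, and status_code describe the outcome, while the body is either the finish_reason-bearing completion or a bare message on error. Below is a sketch of how a caller might consume that pair; the ReplyFn hook and the remapping to 500 are assumptions for illustration, not part of the engine.

// Sketch of consuming the (status, body) pair delivered by the non-stream
// callback above. Field names mirror the diff; ReplyFn is a hypothetical
// transport hook (HTTP handler, IPC, ...), not part of LlamaEngine.
#include <json/json.h>
#include <functional>

using ReplyFn = std::function<void(int /*http_status*/, const Json::Value&)>;

void OnNonStreamResult(const Json::Value& status, const Json::Value& body,
                       const ReplyFn& reply) {
  const bool has_error = status.get("has_error", false).asBool();
  int code = status.get("status_code", 500).asInt();

  // Per the diff, the engine sets status_code to k200OK even when inference
  // failed, so a caller wanting HTTP semantics can remap based on has_error.
  if (has_error) {
    code = 500;
  }
  // On success, body is the completion whose finish_reason is now "stop";
  // on failure it only carries a "message" field.
  reply(code, body);
}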

4 changes: 2 additions & 2 deletions src/LlamaEngine.h
@@ -6,6 +6,7 @@
 
 class LlamaEngine : public EngineI {
 public:
+LlamaEngine();
 ~LlamaEngine() final;
 // #### Interface ####
 void HandleChatCompletion(
@@ -38,8 +39,7 @@ class LlamaEngine : public EngineI {
 void HandleBackgroundTask();
 void StopBackgroundTask();
 
-//TODO(sang) public for now, should move all variables to private section later
-public:
+private:
 llama_server_context llama_;
 std::unique_ptr<trantor::ConcurrentTaskQueue> queue_;
 std::thread bgr_thread_;