diff --git a/src/llama_server_context.cc b/src/llama_server_context.cc index df16551b..90b64e54 100644 --- a/src/llama_server_context.cc +++ b/src/llama_server_context.cc @@ -250,7 +250,6 @@ json LlamaServerContext::GetModelProps() { int LlamaServerContext::RequestCompletion(json data, bool infill, bool embedding, int multitask_id) { - std::unique_lock lock(mutex_tasks); TaskServer task; task.id = id_gen++; task.target_id = 0; @@ -263,12 +262,14 @@ int LlamaServerContext::RequestCompletion(json data, bool infill, // when a completion task's prompt array is not a singleton, we split it // into multiple requests if (task.data.at("prompt").size() > 1) { - lock.unlock(); // entering new func scope return SplitMultipromptTask(task); } // otherwise, it's a single-prompt task, we actually queue it - queue_tasks.push_back(task); + { + std::lock_guard lock(mutex_tasks); + queue_tasks.push_back(task); + } condition_tasks.notify_one(); return task.id; } @@ -303,12 +304,14 @@ TaskResult LlamaServerContext::NextResult(int task_id) { } void LlamaServerContext::RequestCancel(int task_id) { - std::unique_lock lock(mutex_tasks); TaskServer task; task.id = id_gen++; task.type = TaskType::kCancelTask; task.target_id = task_id; - queue_tasks.push_back(task); + { + std::lock_guard lock(mutex_tasks); + queue_tasks.push_back(task); + } condition_tasks.notify_one(); } @@ -820,13 +823,15 @@ void LlamaServerContext::SendError(int id_task, int id_multi, } void LlamaServerContext::AddMultiTask(int id, std::vector& sub_ids) { - std::lock_guard lock(mutex_tasks); TaskMulti multi; multi.id = id; std::copy( sub_ids.begin(), sub_ids.end(), std::inserter(multi.subtasks_remaining, multi.subtasks_remaining.end())); - queue_multitasks.push_back(multi); + { + std::lock_guard lock(mutex_tasks); + queue_multitasks.push_back(multi); + } condition_tasks.notify_one(); } @@ -880,7 +885,6 @@ json LlamaServerContext::GetFormatedGeneration(LlamaClientSlot& slot) { void LlamaServerContext::SendPartialResponse(LlamaClientSlot& slot, CompletionTokenOutput tkn) { - std::unique_lock lock(mutex_results); TaskResult res; res.id = slot.task_id; res.multitask_id = slot.multitask_id; @@ -916,12 +920,14 @@ void LlamaServerContext::SendPartialResponse(LlamaClientSlot& slot, res.result_json["model"] = slot.oaicompat_model; } - queue_results.push_back(res); + { + std::lock_guard lock(mutex_results); + queue_results.push_back(res); + } condition_results.notify_all(); } void LlamaServerContext::SendFinalResponse(LlamaClientSlot& slot) { - std::unique_lock lock(mutex_results); TaskResult res; res.id = slot.task_id; res.multitask_id = slot.multitask_id; @@ -972,12 +978,14 @@ void LlamaServerContext::SendFinalResponse(LlamaClientSlot& slot) { UpdateMultiTask(slot.multitask_id, slot.task_id, res); } - queue_results.push_back(res); + { + std::lock_guard lock(mutex_results); + queue_results.push_back(res); + } condition_results.notify_all(); } void LlamaServerContext::SendEmbedding(LlamaClientSlot& slot) { - std::unique_lock lock(mutex_results); TaskResult res; res.id = slot.task_id; res.multitask_id = slot.multitask_id; @@ -1015,7 +1023,10 @@ void LlamaServerContext::SendEmbedding(LlamaClientSlot& slot) { {"embedding", embd_res}, }; - queue_results.push_back(res); + { + std::lock_guard lock(mutex_results); + queue_results.push_back(res); + } condition_results.notify_all(); } @@ -1111,10 +1122,15 @@ int LlamaServerContext::SplitMultipromptTask(TaskServer& multiprompt_task) { } void LlamaServerContext::ProcessTasks() { - std::unique_lock lock(mutex_tasks); - while (!queue_tasks.empty()) { + while (true) { + std::unique_lock l(mutex_tasks); + if (queue_tasks.empty()) { + l.unlock(); + break; + } TaskServer task = queue_tasks.front(); queue_tasks.erase(queue_tasks.begin()); + l.unlock(); switch (task.type) { case TaskType::kCompletionTask: { LlamaClientSlot* slot = GetSlot(json_value(task.data, "slot_id", -1)); @@ -1155,6 +1171,7 @@ void LlamaServerContext::ProcessTasks() { // remove finished multitasks from the queue of multitasks, and add the // corresponding result to the result queue + std::lock_guard l(mutex_tasks); auto queue_iterator = queue_multitasks.begin(); while (queue_iterator != queue_multitasks.end()) { if (queue_iterator->subtasks_remaining.empty()) { @@ -1172,8 +1189,10 @@ void LlamaServerContext::ProcessTasks() { } aggregate_result.result_json = json{"results", result_jsons}; - std::lock_guard lock(mutex_results); - queue_results.push_back(aggregate_result); + { + std::lock_guard lock(mutex_results); + queue_results.push_back(aggregate_result); + } condition_results.notify_all(); queue_iterator = queue_multitasks.erase(queue_iterator); @@ -1211,8 +1230,6 @@ bool LlamaServerContext::UpdateSlots() { "cache"; KvCacheClear(); } - // std::this_thread::sleep_for(std::chrono::milliseconds(5)); - // TODO: Need to implement queueing using CV for better performance std::unique_lock lock(mutex_tasks); condition_tasks.wait(lock, [&] { return (!queue_tasks.empty() && model_loaded_external) || diff --git a/src/llama_server_context.h b/src/llama_server_context.h index 6fd12095..c71da0d2 100644 --- a/src/llama_server_context.h +++ b/src/llama_server_context.h @@ -117,7 +117,7 @@ struct LlamaServerContext { bool all_slots_are_idle = false; bool add_bos_token = true; - int32_t id_gen; + std::atomic id_gen; int32_t n_ctx; // total context for all clients / slots // Internal @@ -138,7 +138,7 @@ struct LlamaServerContext { std::vector queue_tasks; std::vector queue_results; std::vector queue_multitasks; - std::mutex mutex_tasks; // also guards id_gen, and queue_multitasks + std::mutex mutex_tasks; // also guards queue_multitasks std::condition_variable condition_tasks; std::mutex mutex_results; std::condition_variable condition_results;