From b9db8280f0ecec1934a0aa33304f5ccde6393a25 Mon Sep 17 00:00:00 2001
From: Guoqing Bao
Date: Wed, 24 Jul 2024 17:28:22 +0800
Subject: [PATCH] More elegant way for handling non-streaming finish signal.

---
 src/main.rs                        | 4 +++-
 src/openai/mod.rs                  | 3 ++-
 src/openai/openai_server.rs        | 9 ++++++++-
 src/openai/pipelines/llm_engine.rs | 4 ++++
 tests/tests.rs                     | 3 +++
 5 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/src/main.rs b/src/main.rs
index 2828022..1968be5 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -137,7 +137,7 @@ async fn main() -> Result<(), APIError> {
         dtype: config.kv_cache_dtype,
     };
     println!("Cache config {:?}", cache_config);
-
+    let finish_notify = Arc::new(Notify::new());
     let llm_engine = LLMEngine::new(
         model.0,
         SchedulerConfig {
@@ -145,6 +145,7 @@ async fn main() -> Result<(), APIError> {
         },
         cache_config,
         Arc::new(Notify::new()),
+        finish_notify.clone(),
     )?;
 
     let server_data = OpenAIServerData {
@@ -152,6 +153,7 @@ async fn main() -> Result<(), APIError> {
         model: llm_engine,
         record_conversation: args.record_conversation,
         device: Device::Cpu,
+        finish_notify: finish_notify.clone(),
     };
 
     println!("Server started at http://127.0.0.1:{}.", args.port);
diff --git a/src/openai/mod.rs b/src/openai/mod.rs
index 980f6e7..75d9764 100644
--- a/src/openai/mod.rs
+++ b/src/openai/mod.rs
@@ -1,7 +1,7 @@
 use candle_core::Device;
 use std::sync::Arc;
 use tokenizers::{EncodeInput, Encoding, Tokenizer};
-use tokio::sync::Mutex;
+use tokio::sync::{Mutex, Notify};
 
 use self::{pipelines::llm_engine::LLMEngine, responses::APIError};
 
@@ -45,6 +45,7 @@ pub struct OpenAIServerData {
     pub pipeline_config: PipelineConfig,
     pub record_conversation: bool,
     pub device: Device,
+    pub finish_notify: Arc<Notify>,
 }
 
 pub mod conversation;
diff --git a/src/openai/openai_server.rs b/src/openai/openai_server.rs
index 25c6a67..b380d91 100644
--- a/src/openai/openai_server.rs
+++ b/src/openai/openai_server.rs
@@ -219,8 +219,15 @@ pub async fn chat_completions(
         )
     } else {
         // wait until current response finished
-        tokio::time::sleep(Duration::from_millis(100)).await; //permits generation thread to work
+        data.finish_notify.notified().await;
         let model = data.model.lock().await;
+        if !model.completion_records.contains_key(&request_id) {
+            return ChatResponder::ModelError(APIError::from(format!(
+                "Unable to generate response for request {}",
+                request_id
+            )));
+        }
+
         let choices = &model.completion_records[&request_id].0;
         let usage = &model.completion_records[&request_id].1;
 
diff --git a/src/openai/pipelines/llm_engine.rs b/src/openai/pipelines/llm_engine.rs
index 6d957c9..9232fbb 100644
--- a/src/openai/pipelines/llm_engine.rs
+++ b/src/openai/pipelines/llm_engine.rs
@@ -49,6 +49,7 @@ pub struct LLMEngine {
     cache_engine: CacheEngine,
     sliding_window: Option<usize>,
     pub notify: Arc<Notify>,
+    pub finish_notify: Arc<Notify>,
     pub completion_records: HashMap<String, (Vec<ChatChoice>, ChatCompletionUsageResponse)>,
 }
 
@@ -58,6 +59,7 @@ impl LLMEngine {
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
         notify: Arc<Notify>,
+        finish_notify: Arc<Notify>,
     ) -> Result<Arc<Mutex<Self>>, APIError> {
         let cache_engine = CacheEngine::new(
             pipeline.get_model_config(),
@@ -76,6 +78,7 @@ impl LLMEngine {
             cache_engine,
             sliding_window,
             notify: notify.clone(),
+            finish_notify: finish_notify.clone(),
             completion_records: HashMap::new(),
         }));
         let engine_clone = engine.clone();
@@ -133,6 +136,7 @@ impl LLMEngine {
                     );
                     e.completion_records
                         .insert(request_id.clone(), (choices, usage));
+                    finish_notify.notify_one();
                 }
             });
         });
diff --git a/tests/tests.rs b/tests/tests.rs
index ded36a1..7a65c69 100644
--- a/tests/tests.rs
+++ b/tests/tests.rs
@@ -34,6 +34,7 @@ async fn test_llama() -> Result<(), APIError> {
         None,
     )?;
     let model = loader.load_model(paths, DType::F16, Device::Cpu)?;
+    let finish_notify = Arc::new(Notify::new());
    let llm_engine = LLMEngine::new(
         model.0,
         SchedulerConfig { max_num_seqs: 256 },
@@ -45,6 +46,7 @@ async fn test_llama() -> Result<(), APIError> {
             dtype: DType::F16,
         },
         Arc::new(Notify::new()),
+        finish_notify.clone(),
     )?;
 
     let server_data = OpenAIServerData {
@@ -52,6 +54,7 @@ async fn test_llama() -> Result<(), APIError> {
         model: llm_engine,
         device: Device::Cpu,
         record_conversation: false,
+        finish_notify: finish_notify.clone(),
     };
 
     let allow_origin = AllowOrigin::any();
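
Note (not part of the patch): a minimal, self-contained sketch of the tokio::sync::Notify handshake this change relies on. The generation side stores its result and calls notify_one(); the request handler awaits notified() instead of sleeping and polling. The records vector and the "done" string are illustrative stand-ins for completion_records, not code from this repository.

    use std::sync::Arc;
    use tokio::sync::{Mutex, Notify};

    #[tokio::main]
    async fn main() {
        let finish_notify = Arc::new(Notify::new());
        let records = Arc::new(Mutex::new(Vec::<String>::new()));

        // Generation side: store the result, then signal completion.
        let (notify, store) = (finish_notify.clone(), records.clone());
        tokio::spawn(async move {
            store.lock().await.push("done".to_string());
            notify.notify_one(); // stores a permit if nobody is waiting yet
        });

        // Request side: block on the signal instead of polling with sleep().
        finish_notify.notified().await;
        println!("records: {:?}", *records.lock().await);
    }

Because notify_one() stores a permit when no task is currently waiting, the wakeup is not lost even if generation finishes before the handler reaches notified().await.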