More elegant way for handling non-streaming finish signal.
guoqingbao committed Jul 24, 2024
1 parent 4d9c864 commit b9db828
Showing 5 changed files with 20 additions and 3 deletions.
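At its core, the commit replaces a fixed 100 ms sleep in the non-streaming path with a tokio::sync::Notify handshake: the generation loop signals when a completion record is ready, and the HTTP handler waits for that signal. A minimal sketch of the pattern, separate from the repository code (all names here are illustrative):

    use std::sync::Arc;
    use tokio::sync::Notify;

    #[tokio::main]
    async fn main() {
        // Shared signal: the engine side notifies, the handler side waits.
        let finish_notify = Arc::new(Notify::new());
        let notifier = finish_notify.clone();

        // Stand-in for the generation loop: do the work, then signal.
        tokio::spawn(async move {
            // ... run the model and store the completion record ...
            notifier.notify_one(); // stores a permit if no task is waiting yet
        });

        // Stand-in for the request handler: park until the signal arrives,
        // instead of polling with a sleep.
        finish_notify.notified().await;
        println!("generation finished; response can be assembled");
    }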
4 changes: 3 additions & 1 deletion src/main.rs
@@ -137,21 +137,23 @@ async fn main() -> Result<(), APIError> {
         dtype: config.kv_cache_dtype,
     };
     println!("Cache config {:?}", cache_config);
-
+    let finish_notify = Arc::new(Notify::new());
     let llm_engine = LLMEngine::new(
         model.0,
         SchedulerConfig {
             max_num_seqs: args.max_num_seqs,
         },
         cache_config,
         Arc::new(Notify::new()),
+        finish_notify.clone(),
     )?;

     let server_data = OpenAIServerData {
         pipeline_config: model.1,
         model: llm_engine,
         record_conversation: args.record_conversation,
         device: Device::Cpu,
+        finish_notify: finish_notify.clone(),
     };

     println!("Server started at http://127.0.0.1:{}.", args.port);
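Note that main.rs clones the same Arc<Notify> into both LLMEngine::new and OpenAIServerData, so the engine's generation loop and the HTTP handlers share a single completion signal.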
3 changes: 2 additions & 1 deletion src/openai/mod.rs
@@ -1,7 +1,7 @@
 use candle_core::Device;
 use std::sync::Arc;
 use tokenizers::{EncodeInput, Encoding, Tokenizer};
-use tokio::sync::Mutex;
+use tokio::sync::{Mutex, Notify};

 use self::{pipelines::llm_engine::LLMEngine, responses::APIError};

@@ -45,6 +45,7 @@ pub struct OpenAIServerData {
     pub pipeline_config: PipelineConfig,
     pub record_conversation: bool,
     pub device: Device,
+    pub finish_notify: Arc<Notify>,
 }

 pub mod conversation;
9 changes: 8 additions & 1 deletion src/openai/openai_server.rs
@@ -219,8 +219,15 @@ pub async fn chat_completions(
         )
     } else {
         // wait until current response finished
-        tokio::time::sleep(Duration::from_millis(100)).await; //permits generation thread to work
+        data.finish_notify.notified().await;
         let model = data.model.lock().await;
+        if !model.completion_records.contains_key(&request_id) {
+            return ChatResponder::ModelError(APIError::from(format!(
+                "Unable to generate response for request {}",
+                request_id
+            )));
+        }
+
         let choices = &model.completion_records[&request_id].0;
         let usage = &model.completion_records[&request_id].1;

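The handler-side effect of this hunk: rather than sleeping 100 ms and assuming the generation thread has caught up, the non-streaming path blocks on finish_notify.notified().await and, once woken, checks that a completion record for this request actually exists before indexing into the map, returning a model error otherwise.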
4 changes: 4 additions & 0 deletions src/openai/pipelines/llm_engine.rs
@@ -49,6 +49,7 @@ pub struct LLMEngine {
     cache_engine: CacheEngine,
     sliding_window: Option<usize>,
     pub notify: Arc<Notify>,
+    pub finish_notify: Arc<Notify>,
     pub completion_records: HashMap<String, (Vec<ChatChoice>, ChatCompletionUsageResponse)>,
 }

@@ -58,6 +59,7 @@ impl LLMEngine {
         scheduler_config: SchedulerConfig,
         cache_config: CacheConfig,
         notify: Arc<Notify>,
+        finish_notify: Arc<Notify>,
     ) -> Result<Arc<Mutex<Self>>, APIError> {
         let cache_engine = CacheEngine::new(
             pipeline.get_model_config(),
@@ -76,6 +78,7 @@
             cache_engine,
             sliding_window,
             notify: notify.clone(),
+            finish_notify: finish_notify.clone(),
             completion_records: HashMap::new(),
         }));
         let engine_clone = engine.clone();
@@ -133,6 +136,7 @@ impl LLMEngine {
                 );
                 e.completion_records
                     .insert(request_id.clone(), (choices, usage));
+                finish_notify.notify_one();
             }
         });
     });
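A detail of Notify that makes this ordering safe: notify_one() stores a permit when no task is currently waiting, so if the engine inserts the record and signals before the handler reaches notified().await, the handler's wait still completes immediately instead of hanging.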
3 changes: 3 additions & 0 deletions tests/tests.rs
@@ -34,6 +34,7 @@ async fn test_llama() -> Result<(), APIError> {
         None,
     )?;
     let model = loader.load_model(paths, DType::F16, Device::Cpu)?;
+    let finish_notify = Arc::new(Notify::new());
     let llm_engine = LLMEngine::new(
         model.0,
         SchedulerConfig { max_num_seqs: 256 },
@@ -45,13 +46,15 @@
             dtype: DType::F16,
         },
         Arc::new(Notify::new()),
+        finish_notify.clone(),
     )?;

     let server_data = OpenAIServerData {
         pipeline_config: model.1,
         model: llm_engine,
         device: Device::Cpu,
         record_conversation: false,
+        finish_notify: finish_notify.clone(),
     };

     let allow_origin = AllowOrigin::any();