From 74d2d67aefa34684920c993cc877b989a91f3da6 Mon Sep 17 00:00:00 2001
From: Philpax
Date: Thu, 13 Jul 2023 01:31:52 +0200
Subject: [PATCH] fix(llm): require errors to be Send+Sync

---
 crates/llm-base/src/inference_session.rs | 28 ++++++++++++------------
 1 file changed, 14 insertions(+), 14 deletions(-)

diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs
index a71a57fd..caff5e67 100644
--- a/crates/llm-base/src/inference_session.rs
+++ b/crates/llm-base/src/inference_session.rs
@@ -280,7 +280,7 @@ impl InferenceSession {
     }
 
     /// Feed a prompt to the model for this session.
-    pub fn feed_prompt<'a, E: std::error::Error + 'static, P: Into<Prompt<'a>>>(
+    pub fn feed_prompt<'a, E: std::error::Error + Send + Sync + 'static, P: Into<Prompt<'a>>>(
         &mut self,
         model: &dyn Model,
         params: &InferenceParameters,
@@ -407,7 +407,7 @@ impl InferenceSession {
     /// generated (specified by [InferenceRequest::maximum_token_count]).
     ///
     /// This is a wrapper around [Self::feed_prompt] and [Self::infer_next_token].
-    pub fn infer<E: std::error::Error + 'static>(
+    pub fn infer<E: std::error::Error + Send + Sync + 'static>(
         &mut self,
         model: &dyn Model,
         rng: &mut impl rand::Rng,
@@ -440,13 +440,13 @@
         // Feed the initial prompt through the transformer, to update its
         // context window with new data, if necessary.
         if !request.prompt.is_empty() {
-            self.feed_prompt(
-                model,
-                parameters,
-                request.prompt,
-                output_request,
-                feed_prompt_callback(&mut callback),
-            )?;
+            self.feed_prompt(
+                model,
+                parameters,
+                request.prompt,
+                output_request,
+                feed_prompt_callback(&mut callback),
+            )?;
         }
         stats.feed_prompt_duration = start_at.elapsed().unwrap();
         stats.prompt_tokens = self.n_past;
@@ -663,7 +663,7 @@ pub enum InferenceError {
     EndOfText,
     #[error("the user-specified callback returned an error")]
     /// The user-specified callback returned an error.
-    UserCallback(Box<dyn std::error::Error>),
+    UserCallback(Box<dyn std::error::Error + Send + Sync>),
 }
 
 #[derive(Error, Debug)]
@@ -885,7 +885,7 @@ pub enum InferenceFeedback {
 
 /// Adapt an [InferenceResponse] callback so that it can be used in a call to
 /// [InferenceSession::feed_prompt].
-pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
+pub fn feed_prompt_callback<'a, E: std::error::Error + Send + Sync + 'static>(
     mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a,
 ) -> impl FnMut(&[u8]) -> Result<InferenceFeedback, E> + 'a {
     let mut buffer = TokenUtf8Buffer::new();
@@ -897,8 +897,8 @@
 
 /// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
 /// This callback is used in [InferenceSession::infer] in chat_mode.
-pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
-    stop_sequence: String,
+pub fn conversation_inference_callback<'a, E: std::error::Error + Send + Sync + 'static>(
+    stop_sequence: &'a str,
     mut callback: impl FnMut(String) + 'a,
 ) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
     let mut stop_sequence_buf = String::new();
@@ -908,7 +908,7 @@ pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
             let mut buf = stop_sequence_buf.clone();
            buf.push_str(&token);
 
-            if buf.starts_with(&stop_sequence) {
+            if buf.starts_with(stop_sequence) {
                 // We've generated the stop sequence, so we're done.
                 // Note that this will contain the extra tokens that were generated after the stop sequence,
                 // which may affect generation. This is non-ideal, but it's the best we can do without
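
For context, a minimal standalone sketch of the pattern this patch enforces: the user-supplied error type E is tightened from `std::error::Error + 'static` to `std::error::Error + Send + Sync + 'static`, so the boxed error stored in InferenceError::UserCallback becomes Box<dyn std::error::Error + Send + Sync> and can be moved across threads. The names below (run_with_callback, StopRequested) are hypothetical stand-ins, not code from the llm crates.

use std::error::Error;
use std::fmt;
use std::thread;

// Hypothetical stand-in for a callback-driven API like feed_prompt/infer:
// the callback's error type E must be Send + Sync + 'static so it can be
// boxed into a thread-safe trait object, mirroring InferenceError::UserCallback.
fn run_with_callback<E: Error + Send + Sync + 'static>(
    mut callback: impl FnMut(&str) -> Result<(), E>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
    for token in ["Hello", ", ", "world"] {
        // The Send + Sync bound on E is what allows boxing the error as
        // Box<dyn Error + Send + Sync> here.
        callback(token).map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)?;
    }
    Ok(())
}

// A caller-defined error type; a plain struct with no non-Send data is
// automatically Send + Sync, so it satisfies the new bound.
#[derive(Debug)]
struct StopRequested;

impl fmt::Display for StopRequested {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "stop requested by callback")
    }
}

impl Error for StopRequested {}

fn main() {
    let result = run_with_callback(|token| {
        if token.contains("world") {
            Err(StopRequested) // surfaces as Box<dyn Error + Send + Sync>
        } else {
            println!("fed token: {token}");
            Ok(())
        }
    });

    // Because the boxed error is Send + 'static, it can be moved to another
    // thread, e.g. for logging; with only `Error + 'static` this would not compile.
    thread::spawn(move || {
        if let Err(e) = result {
            eprintln!("callback error: {e}");
        }
    })
    .join()
    .unwrap();
}

The final thread::spawn is the practical motivation for the stricter bound: only a Send + Sync error can be reported from, or shared between, threads once it has been boxed.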