This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

Commit 74d2d67

fix(llm): require errors to be Send+Sync

philpax committed Jul 12, 2023 (1 parent: 710f3c2)
Showing 1 changed file with 14 additions and 14 deletions.
crates/llm-base/src/inference_session.rs (14 additions, 14 deletions)

@@ -280,7 +280,7 @@ impl InferenceSession {
     }
 
     /// Feed a prompt to the model for this session.
-    pub fn feed_prompt<'a, E: std::error::Error + 'static, P: Into<Prompt<'a>>>(
+    pub fn feed_prompt<'a, E: std::error::Error + Send + Sync + 'static, P: Into<Prompt<'a>>>(
         &mut self,
         model: &dyn Model,
         params: &InferenceParameters,
@@ -407,7 +407,7 @@ impl InferenceSession {
     /// generated (specified by [InferenceRequest::maximum_token_count]).
     ///
     /// This is a wrapper around [Self::feed_prompt] and [Self::infer_next_token].
-    pub fn infer<E: std::error::Error + 'static>(
+    pub fn infer<E: std::error::Error + Send + Sync + 'static>(
         &mut self,
         model: &dyn Model,
         rng: &mut impl rand::Rng,
@@ -440,13 +440,13 @@ impl InferenceSession {
         // Feed the initial prompt through the transformer, to update its
         // context window with new data, if necessary.
         if !request.prompt.is_empty() {
-            self.feed_prompt(
-                model,
-                parameters,
-                request.prompt,
-                output_request,
-                feed_prompt_callback(&mut callback),
-            )?;
+            self.feed_prompt(
+                model,
+                parameters,
+                request.prompt,
+                output_request,
+                feed_prompt_callback(&mut callback),
+            )?;
         }
         stats.feed_prompt_duration = start_at.elapsed().unwrap();
         stats.prompt_tokens = self.n_past;
@@ -663,7 +663,7 @@ pub enum InferenceError {
     EndOfText,
     #[error("the user-specified callback returned an error")]
     /// The user-specified callback returned an error.
-    UserCallback(Box<dyn std::error::Error>),
+    UserCallback(Box<dyn std::error::Error + Send + Sync>),
 }
 
 #[derive(Error, Debug)]
@@ -885,7 +885,7 @@ pub enum InferenceFeedback {
 
 /// Adapt an [InferenceResponse] callback so that it can be used in a call to
 /// [InferenceSession::feed_prompt].
-pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
+pub fn feed_prompt_callback<'a, E: std::error::Error + Send + Sync + 'static>(
     mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a,
 ) -> impl FnMut(&[u8]) -> Result<InferenceFeedback, E> + 'a {
     let mut buffer = TokenUtf8Buffer::new();
@@ -897,8 +897,8 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
 
 /// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
 /// This callback is used in [InferenceSession::infer] in chat_mode.
-pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
-    stop_sequence: String,
+pub fn conversation_inference_callback<'a, E: std::error::Error + Send + Sync + 'static>(
+    stop_sequence: &'a str,
     mut callback: impl FnMut(String) + 'a,
 ) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
     let mut stop_sequence_buf = String::new();
@@ -908,7 +908,7 @@ pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
             let mut buf = stop_sequence_buf.clone();
             buf.push_str(&token);
 
-            if buf.starts_with(&stop_sequence) {
+            if buf.starts_with(stop_sequence) {
                 // We've generated the stop sequence, so we're done.
                 // Note that this will contain the extra tokens that were generated after the stop sequence,
                 // which may affect generation. This is non-ideal, but it's the best we can do without
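
For readers hitting this from the caller's side: the practical effect of the new bounds is that any error type returned from a `feed_prompt`/`infer` callback must now be `Send + Sync`, so the boxed `Box<dyn Error + Send + Sync>` stored in `InferenceError::UserCallback` can cross thread boundaries. A minimal standalone sketch of the same pattern (none of this is code from the crate; `CallbackError` and `run_callback` are made-up names for illustration):

```rust
use std::error::Error;
use std::fmt;

// A typical callback error: a plain enum is automatically Send + Sync,
// because auto traits propagate from its (thread-safe) contents.
#[derive(Debug)]
enum CallbackError {
    Cancelled,
}

impl fmt::Display for CallbackError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "callback cancelled")
    }
}

impl Error for CallbackError {}

// Mirrors the tightened bound: with E: Error + Send + Sync + 'static the
// error can be boxed as Box<dyn Error + Send + Sync>, which is itself Send.
fn run_callback<E: Error + Send + Sync + 'static>(
    callback: impl FnOnce() -> Result<(), E>,
) -> Result<(), Box<dyn Error + Send + Sync>> {
    callback().map_err(|e| Box::new(e) as Box<dyn Error + Send + Sync>)
}

fn main() {
    let err = run_callback(|| Err::<(), _>(CallbackError::Cancelled)).unwrap_err();

    // The boxed error can be moved to another thread; with the old
    // Box<dyn Error> this line would not compile.
    std::thread::spawn(move || println!("error on another thread: {err}"))
        .join()
        .unwrap();
}
```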

4 comments on commit 74d2d67

@pixelspark
Contributor

Brilliant, this was biting me the other day 👍

@philpax
Collaborator Author

Yeah it bit me a few times - I was trying to avoid it because I didn't want to require all errors to be Send+Sync, but I didn't want to deal with the headache of trying to genericise the error over Send/Sync. We'll see if anyone complains 😅

@pixelspark
Contributor

Well, I think it is not uncommon for error types to be Send+Sync anyway. They should be lightweight objects: either simple enums/constant strings, or allocated anew for each error (with specific error information). Send is easy to achieve, and Sync is no issue for an object that isn't really shared.
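
To put a bit of code behind that point (hypothetical types, not from this crate): Send and Sync are auto traits, so the common lightweight error shapes pick them up for free; only an error that drags in non-thread-safe state, such as an Rc, falls outside the new bound.

```rust
use std::rc::Rc;

// A typical lightweight error: Send and Sync are auto traits, so a plain
// enum of thread-safe data gets both automatically.
#[derive(Debug)]
enum SimpleError {
    NotFound,
    Parse(String),
}

// An error that captures non-thread-safe state loses the auto traits,
// because Rc is neither Send nor Sync.
#[derive(Debug)]
struct SharedStateError {
    context: Rc<String>,
}

fn assert_send_sync<T: Send + Sync>() {}

fn main() {
    assert_send_sync::<SimpleError>(); // compiles: the enum is Send + Sync
    // assert_send_sync::<SharedStateError>(); // would not compile

    println!("{:?}", SimpleError::NotFound);
    println!("{:?}", SimpleError::Parse(String::from("bad token")));
    println!("{:?}", SharedStateError { context: Rc::new(String::from("ctx")) });
}
```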

@philpax
Collaborator Author

Yeah, the only concern I have is single-threaded environments where the traits aren't relevant/can't be implemented (i.e. WASM). That's a bridge we can cross when we get to it.
