allow chat to halt new token generation on stop_sequence (#364)
Changes from 12 commits
```diff
@@ -280,7 +280,7 @@ impl InferenceSession {
     }
 
     /// Feed a prompt to the model for this session.
-    pub fn feed_prompt<'a, E: std::error::Error + 'static, P: Into<Prompt<'a>>>(
+    pub fn feed_prompt<'a, E: std::error::Error + Send + Sync + 'static, P: Into<Prompt<'a>>>(
         &mut self,
         model: &dyn Model,
         params: &InferenceParameters,
```
```diff
@@ -407,7 +407,7 @@ impl InferenceSession {
     /// generated (specified by [InferenceRequest::maximum_token_count]).
     ///
     /// This is a wrapper around [Self::feed_prompt] and [Self::infer_next_token].
-    pub fn infer<E: std::error::Error + 'static>(
+    pub fn infer<E: std::error::Error + Send + Sync + 'static>(
         &mut self,
         model: &dyn Model,
         rng: &mut impl rand::Rng,
```
```diff
@@ -438,14 +438,16 @@ impl InferenceSession {
         let parameters = request.parameters;
 
         // Feed the initial prompt through the transformer, to update its
-        // context window with new data.
-        self.feed_prompt(
-            model,
-            parameters,
-            request.prompt,
-            output_request,
-            feed_prompt_callback(&mut callback),
-        )?;
+        // context window with new data, if necessary.
+        if !request.prompt.is_empty() {
+            self.feed_prompt(
+                model,
+                parameters,
+                request.prompt,
+                output_request,
+                feed_prompt_callback(&mut callback),
+            )?;
+        }
         stats.feed_prompt_duration = start_at.elapsed().unwrap();
         stats.prompt_tokens = self.n_past;
 
```
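This guard is what allows a chat loop to call `infer` with an empty prompt and simply continue generating from a context that was fed earlier. A minimal sketch of such a call follows; the `InferenceRequest` field names and `llm::` re-exports are assumptions based on this revision, and the helper itself is illustrative, not code from this PR:

```rust
use std::convert::Infallible;

/// Hypothetical helper: continue generating from an already-fed session.
/// Because the prompt is empty, the new `is_empty` guard skips
/// `feed_prompt` entirely and sampling resumes from the session's
/// existing context window.
fn continue_generation(
    session: &mut llm::InferenceSession,
    model: &dyn llm::Model,
    rng: &mut impl rand::Rng,
) -> Result<llm::InferenceStats, llm::InferenceError> {
    session.infer::<Infallible>(
        model,
        rng,
        &llm::InferenceRequest {
            prompt: "".into(), // empty: no prompt feeding, just generation
            parameters: &llm::InferenceParameters::default(),
            play_back_previous_tokens: false,
            maximum_token_count: Some(256),
        },
        &mut Default::default(),
        |resp| {
            // Print each newly generated token as it arrives.
            if let llm::InferenceResponse::InferredToken(t) = resp {
                print!("{t}");
            }
            Ok(llm::InferenceFeedback::Continue)
        },
    )
}
```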
```diff
@@ -661,7 +663,7 @@ pub enum InferenceError {
     EndOfText,
     #[error("the user-specified callback returned an error")]
     /// The user-specified callback returned an error.
-    UserCallback(Box<dyn std::error::Error>),
+    UserCallback(Box<dyn std::error::Error + Send + Sync>),
 }
 
 #[derive(Error, Debug)]
```
```diff
@@ -883,7 +885,7 @@ pub enum InferenceFeedback {
 
 /// Adapt an [InferenceResponse] callback so that it can be used in a call to
 /// [InferenceSession::feed_prompt].
-pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
+pub fn feed_prompt_callback<'a, E: std::error::Error + Send + Sync + 'static>(
     mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a,
 ) -> impl FnMut(&[u8]) -> Result<InferenceFeedback, E> + 'a {
     let mut buffer = TokenUtf8Buffer::new();
```
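This adapter exists because `feed_prompt` reports raw token bytes (`&[u8]`), which may split multi-byte UTF-8 characters; `feed_prompt_callback` accumulates them in a `TokenUtf8Buffer` and only forwards complete strings to the wrapped callback. A hedged sketch of using it directly, under the same naming assumptions as above (`session` and `model` are assumed to be in scope):

```rust
use std::convert::Infallible;

// Hypothetical direct use of the adapter: print each prompt token as the
// model ingests it. `feed_prompt_callback` converts the
// `InferenceResponse`-level closure into the byte-level callback that
// `feed_prompt` actually takes.
session.feed_prompt(
    model,
    &llm::InferenceParameters::default(),
    "User: Hello!\nAssistant:",
    &mut Default::default(),
    llm::feed_prompt_callback::<Infallible>(|resp| {
        if let llm::InferenceResponse::PromptToken(t) = resp {
            print!("{t}");
        }
        Ok(llm::InferenceFeedback::Continue)
    }),
)?;
```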
```diff
@@ -892,3 +894,40 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
             None => Ok(InferenceFeedback::Continue),
         }
     }
+
+/// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
+/// This callback is used in [InferenceSession::infer] in chat_mode.
+pub fn conversation_inference_callback<'a, E: std::error::Error + Send + Sync + 'static>(
+    stop_sequence: &'a str,
+    mut callback: impl FnMut(String) + 'a,
+) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
+    let mut stop_sequence_buf = String::new();
+
+    move |resp| match resp {
+        InferenceResponse::InferredToken(token) => {
+            // We've generated a token, so we need to check if it's contained in the stop sequence.
+            let mut buf = stop_sequence_buf.clone();
+            buf.push_str(&token);
+
+            if buf.starts_with(stop_sequence) {
+                // We've generated the stop sequence, so we're done.
+                // Note that this will contain the extra tokens that were generated after the stop sequence,
+                // which may affect generation. This is non-ideal, but it's the best we can do without
+                // modifying the model.
+                stop_sequence_buf.clear();
+                return Ok(InferenceFeedback::Halt);
+            } else if stop_sequence.starts_with(&buf) {
+                // We've generated a prefix of the stop sequence, so we need to keep buffering.
+                stop_sequence_buf = buf;
+                return Ok(InferenceFeedback::Continue);
+            }
+
+            // We've generated a token that isn't part of the stop sequence, so we can
+            // pass it to the callback.
+            stop_sequence_buf.clear();
+            callback(buf);
+            Ok(InferenceFeedback::Continue)
+        }
+        InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
+        _ => Ok(InferenceFeedback::Continue),
+    }
+}
```

Review comment on `let mut stop_sequence_buf = String::new();`:

> interesting, scoping this buffer to the function seems a lot better!

Review comment on the `E: std::error::Error + Send + Sync + 'static` bounds:

> what do these do?

Reply:

> In Rust, objects get the […] I needed to add this because […]
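To answer that question in brief: `Send` and `Sync` are Rust's thread-safety marker traits (`Send`: ownership may move to another thread; `Sync`: `&T` may be shared between threads). They are auto traits, so the compiler implements them for any type built entirely from `Send`/`Sync` parts, which most error types are; widening the bounds to `dyn std::error::Error + Send + Sync` records that fact in the signature and lets a boxed callback error cross thread boundaries, for example when inference runs on a worker thread. A self-contained illustration (not code from this PR):

```rust
use std::error::Error;
use std::fmt;

// A toy error type. All of its fields are `Send + Sync`, so the
// compiler implements those auto traits for it automatically.
#[derive(Debug)]
struct ChatError(String);

impl fmt::Display for ChatError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "chat error: {}", self.0)
    }
}

impl Error for ChatError {}

fn main() {
    // With plain `Box<dyn Error>` the spawn below would not compile,
    // because the closure capturing the error could not be sent to
    // another thread. The `+ Send + Sync` bounds make it legal.
    let err: Box<dyn Error + Send + Sync> = Box::new(ChatError("oops".into()));
    std::thread::spawn(move || eprintln!("{err}")).join().unwrap();
}
```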
Review comment on the new function:

> wow very cool!
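Putting the pieces together, here is a hedged sketch of a chat turn built on the new helper; aside from `conversation_inference_callback` itself, the names carry the same assumptions as the earlier sketches:

```rust
use std::convert::Infallible;

/// Hypothetical chat turn: feed the user's message, then generate until
/// the model tries to start the next "User:" turn.
fn chat_turn(
    session: &mut llm::InferenceSession,
    model: &dyn llm::Model,
    rng: &mut impl rand::Rng,
    message: &str,
) -> Result<llm::InferenceStats, llm::InferenceError> {
    let prompt = format!("User: {message}\nAssistant:");
    session.infer(
        model,
        rng,
        &llm::InferenceRequest {
            prompt: prompt.as_str().into(),
            parameters: &llm::InferenceParameters::default(),
            play_back_previous_tokens: false,
            maximum_token_count: None,
        },
        &mut Default::default(),
        llm::conversation_inference_callback::<Infallible>("User:", |token| {
            // Only text that is provably not part of "User:" arrives here.
            print!("{token}");
        }),
    )
}
```

Because of the prefix buffering, a partial match such as `User` is held back until the following token disambiguates it, so the printed transcript never contains fragments of the stop sequence.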