diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs
index 0e7520e7..b1e13d90 100644
--- a/binaries/llm-cli/src/main.rs
+++ b/binaries/llm-cli/src/main.rs
@@ -259,13 +259,6 @@ fn interactive(
     let parameters = generate.inference_parameters(model.eot_token_id());
     let mut rng = generate.rng();
 
-    let mut buf = String::new();
-
-    fn print_token(t: String) {
-        print!("{t}");
-        std::io::stdout().flush().unwrap();
-    }
-
     fn session_ends_with_newline(session: &InferenceSession) -> bool {
         session
             .decoded_tokens()
@@ -293,6 +286,11 @@ fn interactive(
         };
         sp.clear();
 
+        fn print_token(t: String) {
+            print!("{t}");
+            std::io::stdout().flush().unwrap();
+        }
+
         if chat_mode {
             let stop_sequence = message_prompt_prefix.unwrap_or_default().to_owned();
 
@@ -306,7 +304,7 @@ fn interactive(
                     maximum_token_count: generate.num_predict,
                 },
                 &mut Default::default(),
-                conversation_inference_callback(stop_sequence, &mut buf, print_token),
+                conversation_inference_callback(stop_sequence, print_token),
             )
         } else {
             session.infer::<Infallible>(
@@ -348,7 +346,7 @@ fn interactive(
                 .map(|template| process_prompt(template, &line))
                 .unwrap_or_else(|| {
                     message_prompt_prefix
-                        .map(|prefix| format!("{} {}", prefix, line))
+                        .map(|prefix| format!("{}{}", prefix, line))
                         .unwrap_or_else(|| line)
                 });
 
diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs
index 5db4b71e..e310ffa4 100644
--- a/crates/llm-base/src/inference_session.rs
+++ b/crates/llm-base/src/inference_session.rs
@@ -893,32 +893,37 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
     }
 }
 
-/// An [InferenceResponse] callback that will halt inference when a stop_sequence is generated.
+/// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
 /// This callback is used in [InferenceSession::infer] in chat_mode.
 pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
     stop_sequence: String,
-    buf: &'a mut String,
-    mut print_token: impl FnMut(String) + 'a,
+    mut callback: impl FnMut(String) + 'a,
 ) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
+    let mut stop_sequence_buf = String::new();
     move |resp| match resp {
-        InferenceResponse::InferredToken(t) => {
-            let mut reverse_buf = buf.clone();
-            reverse_buf.push_str(t.as_str());
-            if stop_sequence.as_str().eq(reverse_buf.as_str()) {
-                buf.clear();
+        InferenceResponse::InferredToken(token) => {
+            // We've generated a token, so we need to check if it's contained in the stop sequence.
+            let mut buf = stop_sequence_buf.clone();
+            buf.push_str(&token);
+
+            if buf.starts_with(&stop_sequence) {
+                // We've generated the stop sequence, so we're done.
+                // Note that this will contain the extra tokens that were generated after the stop sequence,
+                // which may affect generation. This is non-ideal, but it's the best we can do without
+                // modifying the model.
+                stop_sequence_buf.clear();
                 return Ok(InferenceFeedback::Halt);
-            } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) {
-                buf.push_str(t.as_str());
+            } else if stop_sequence.starts_with(&buf) {
+                // We've generated a prefix of the stop sequence, so we need to keep buffering.
+                stop_sequence_buf = buf;
                 return Ok(InferenceFeedback::Continue);
             }
-            if buf.is_empty() {
-                print_token(t);
-                Ok(InferenceFeedback::Continue)
-            } else {
-                print_token(reverse_buf);
-                Ok(InferenceFeedback::Continue)
-            }
+            // We've generated a token that isn't part of the stop sequence, so we can
+            // pass it to the callback.
+            stop_sequence_buf.clear();
+            callback(buf);
+            Ok(InferenceFeedback::Continue)
         }
         InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
         _ => Ok(InferenceFeedback::Continue),
     }
diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs
index a67810cc..ecf73de4 100644
--- a/crates/llm/examples/vicuna-chat.rs
+++ b/crates/llm/examples/vicuna-chat.rs
@@ -77,7 +77,6 @@ fn main() {
 
     let mut rng = rand::thread_rng();
     let mut res = llm::InferenceStats::default();
-    let mut buf = String::new();
 
     loop {
         println!();
@@ -98,11 +97,7 @@ fn main() {
                 maximum_token_count: None,
             },
             &mut Default::default(),
-            conversation_inference_callback(
-                String::from(user_name),
-                &mut buf,
-                print_token,
-            ),
+            conversation_inference_callback(format!("{character_name}:"), print_token),
         )
         .unwrap_or_else(|e| panic!("{e}"));
 
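Design note on the refactor above: the scratch buffer now lives inside the closure returned by conversation_inference_callback (the new `let mut stop_sequence_buf = String::new();` captured by `move`), so the callback is self-contained and call sites no longer allocate and thread a `&mut String` through. That is what the `let mut buf = String::new();` deletions in both binaries/llm-cli/src/main.rs and crates/llm/examples/vicuna-chat.rs reflect. A second behavioral detail: when a buffered prefix turns out not to be the stop sequence, the whole buffer is flushed via a single `callback(buf)` call rather than token by token.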
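For reviewers, here is a minimal self-contained sketch of the buffering state machine this patch introduces, decoupled from the llm crates so it can be compiled and run directly. StopSequenceFilter, Verdict, push, and the driver in main are hypothetical names invented for illustration (Verdict stands in for InferenceFeedback); only the three-way branch mirrors conversation_inference_callback.

// Illustrative only: standalone model of the stop-sequence buffering logic.
#[derive(PartialEq)]
enum Verdict {
    Continue,
    Halt,
}

struct StopSequenceFilter {
    stop_sequence: String,
    buf: String,
}

impl StopSequenceFilter {
    fn new(stop_sequence: impl Into<String>) -> Self {
        Self {
            stop_sequence: stop_sequence.into(),
            buf: String::new(),
        }
    }

    // Mirrors the InferredToken arm: buffer tokens that could begin the stop
    // sequence, halt once the buffer starts with it, and otherwise flush the
    // whole buffer (not just the latest token) to the callback.
    fn push(&mut self, token: &str, mut callback: impl FnMut(&str)) -> Verdict {
        let mut buf = self.buf.clone();
        buf.push_str(token);

        if buf.starts_with(&self.stop_sequence) {
            // Stop sequence generated (possibly with trailing extra text): halt.
            self.buf.clear();
            Verdict::Halt
        } else if self.stop_sequence.starts_with(&buf) {
            // Still a prefix of the stop sequence: keep buffering, emit nothing.
            self.buf = buf;
            Verdict::Continue
        } else {
            // Not the stop sequence after all: flush everything held back.
            self.buf.clear();
            callback(&buf);
            Verdict::Continue
        }
    }
}

fn main() {
    let mut filter = StopSequenceFilter::new("USER:");
    let mut out = String::new();

    // "USE" is held back as a possible stop-sequence prefix; "FUL ones. "
    // disproves it, so the buffered text is flushed in one piece.
    for token in ["Tips: ", "USE", "FUL ones. ", "USER", ": hi"] {
        if filter.push(token, |s| out.push_str(s)) == Verdict::Halt {
            break;
        }
    }

    assert_eq!(out, "Tips: USEFUL ones. ");
    println!("{out}");
}

Running this prints "Tips: USEFUL ones. ": the held-back "USE" is emitted together with the text that disproved it, and the loop halts as soon as the buffer starts with "USER:", matching the `buf.starts_with(&stop_sequence)` branch in the patch.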