This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

fix(llm): clarify conversation_inference_callback
philpax committed Jul 12, 2023
1 parent a0ad8b4 commit 138263a
Showing 3 changed files with 30 additions and 32 deletions.
binaries/llm-cli/src/main.rs (16 changes: 7 additions & 9 deletions)

@@ -259,13 +259,6 @@ fn interactive(
     let parameters = generate.inference_parameters(model.eot_token_id());
     let mut rng = generate.rng();
 
-    let mut buf = String::new();
-
-    fn print_token(t: String) {
-        print!("{t}");
-        std::io::stdout().flush().unwrap();
-    }
-
     fn session_ends_with_newline(session: &InferenceSession) -> bool {
         session
             .decoded_tokens()
@@ -293,6 +286,11 @@ fn interactive(
     };
     sp.clear();
 
+    fn print_token(t: String) {
+        print!("{t}");
+        std::io::stdout().flush().unwrap();
+    }
+
     if chat_mode {
         let stop_sequence = message_prompt_prefix.unwrap_or_default().to_owned();
 
@@ -306,7 +304,7 @@ fn interactive(
                     maximum_token_count: generate.num_predict,
                 },
                 &mut Default::default(),
-                conversation_inference_callback(stop_sequence, &mut buf, print_token),
+                conversation_inference_callback(stop_sequence, print_token),
             )
         } else {
            session.infer::<Infallible>(
@@ -348,7 +346,7 @@ fn interactive(
            .map(|template| process_prompt(template, &line))
            .unwrap_or_else(|| {
                message_prompt_prefix
-                    .map(|prefix| format!("{} {}", prefix, line))
+                    .map(|prefix| format!("{}{}", prefix, line))
                    .unwrap_or_else(|| line)
            });
 
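
The CLI half of the change is mechanical: the shared `buf` disappears (the callback now does its own buffering, as the next file shows), and the `print_token` helper moves next to its use site — since `fn` items are position-independent in Rust, the move is purely organizational. A minimal, self-contained sketch of why the plain `fn` still satisfies the callback's new sink parameter; `takes_sink` is an illustrative stand-in, not part of the llm API:

use std::io::Write;

// Mirrors the helper in the diff: write a token and flush so it appears immediately.
fn print_token(t: String) {
    print!("{t}");
    std::io::stdout().flush().unwrap();
}

// Stand-in for the callback's sink parameter: a `fn(String)` item implements
// `FnMut(String)`, which is why `conversation_inference_callback(stop_sequence,
// print_token)` works without threading `&mut buf` through.
fn takes_sink(mut sink: impl FnMut(String)) {
    for t in ["Hello", ", ", "world!\n"] {
        sink(t.to_string());
    }
}

fn main() {
    takes_sink(print_token);
}
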
crates/llm-base/src/inference_session.rs (39 changes: 22 additions & 17 deletions)

@@ -893,32 +893,37 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
     }
 }
 
-/// An [InferenceResponse] callback that will halt inference when a stop_sequence is generated.
+/// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
 /// This callback is used in [InferenceSession::infer] in chat_mode.
 pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
     stop_sequence: String,
-    buf: &'a mut String,
-    mut print_token: impl FnMut(String) + 'a,
+    mut callback: impl FnMut(String) + 'a,
 ) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
+    let mut stop_sequence_buf = String::new();
     move |resp| match resp {
-        InferenceResponse::InferredToken(t) => {
-            let mut reverse_buf = buf.clone();
-            reverse_buf.push_str(t.as_str());
-            if stop_sequence.as_str().eq(reverse_buf.as_str()) {
-                buf.clear();
+        InferenceResponse::InferredToken(token) => {
+            // We've generated a token, so we need to check if it's contained in the stop sequence.
+            let mut buf = stop_sequence_buf.clone();
+            buf.push_str(&token);
+
+            if buf.starts_with(&stop_sequence) {
+                // We've generated the stop sequence, so we're done.
+                // Note that this will contain the extra tokens that were generated after the stop sequence,
+                // which may affect generation. This is non-ideal, but it's the best we can do without
+                // modifying the model.
+                stop_sequence_buf.clear();
                 return Ok(InferenceFeedback::Halt);
-            } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) {
-                buf.push_str(t.as_str());
+            } else if stop_sequence.starts_with(&buf) {
+                // We've generated a prefix of the stop sequence, so we need to keep buffering.
+                stop_sequence_buf = buf;
                 return Ok(InferenceFeedback::Continue);
             }
 
-            if buf.is_empty() {
-                print_token(t);
-                Ok(InferenceFeedback::Continue)
-            } else {
-                print_token(reverse_buf);
-                Ok(InferenceFeedback::Continue)
-            }
+            // We've generated a token that isn't part of the stop sequence, so we can
+            // pass it to the callback.
+            stop_sequence_buf.clear();
+            callback(buf);
+            Ok(InferenceFeedback::Continue)
         }
         InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
         _ => Ok(InferenceFeedback::Continue),
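
The heart of the commit is above: instead of threading a caller-owned `buf` through the function, the callback now owns a `stop_sequence_buf` and runs a three-way check per token — halt if the buffer starts with the stop sequence, keep buffering if it is still a prefix of it, otherwise flush the buffer to the callback. The following self-contained sketch re-derives that logic outside the crate so its behavior on a stop sequence split across tokens can be checked directly; `make_filter` and the token stream are illustrative, not llm API:

// Returns a per-token filter; the filter returns true when inference should halt.
fn make_filter(
    stop_sequence: String,
    mut callback: impl FnMut(String),
) -> impl FnMut(String) -> bool {
    let mut stop_sequence_buf = String::new();
    move |token: String| {
        let mut buf = stop_sequence_buf.clone();
        buf.push_str(&token);

        if buf.starts_with(&stop_sequence) {
            // Full stop sequence generated: halt (InferenceFeedback::Halt in the real code).
            true
        } else if stop_sequence.starts_with(&buf) {
            // Still a prefix of the stop sequence: buffer it and emit nothing yet.
            stop_sequence_buf = buf;
            false
        } else {
            // Not the stop sequence after all: flush everything buffered, then continue.
            stop_sequence_buf.clear();
            callback(buf);
            false
        }
    }
}

fn main() {
    let mut out = String::new();
    let mut filter = make_filter("### Human:".into(), |s| out.push_str(&s));
    // "##" and "# Hum" look like the start of the stop sequence, so they are
    // buffered rather than emitted; "an:" completes it and halts.
    for token in ["Hi", " there", "!", "##", "# Hum", "an:"] {
        if filter(token.to_string()) {
            break;
        }
    }
    assert_eq!(out, "Hi there!"); // the stop sequence itself is never emitted
    println!("{out}");
}

As the in-tree comment notes, a single token can straddle the end of the stop sequence, so the halting buffer may contain a few trailing characters beyond it.
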
crates/llm/examples/vicuna-chat.rs (7 changes: 1 addition & 6 deletions)

@@ -77,7 +77,6 @@ fn main() {
 
     let mut rng = rand::thread_rng();
     let mut res = llm::InferenceStats::default();
-    let mut buf = String::new();
 
     loop {
         println!();
@@ -98,11 +97,7 @@
                 maximum_token_count: None,
             },
             &mut Default::default(),
-            conversation_inference_callback(
-                String::from(user_name),
-                &mut buf,
-                print_token,
-            ),
+            conversation_inference_callback(format!("{character_name}:"), print_token),
        )
        .unwrap_or_else(|e| panic!("{e}"));
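
The example also changes what it stops on: the character's turn prefix rather than the bare `user_name`. A tiny illustrative sketch of the stop sequence being constructed — the name here is a hypothetical stand-in, not taken from the example's actual persona setup:

fn main() {
    // Hypothetical name for illustration; vicuna-chat.rs defines its own.
    let character_name = "### Assistant";
    let stop_sequence = format!("{character_name}:");
    assert_eq!(stop_sequence, "### Assistant:");
    // With the rewritten callback, generation halts as soon as the model emits
    // this prefix to begin another turn, and no shared `buf` is needed.
    println!("halting on {stop_sequence:?}");
}
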
