This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

allow chat to halt new token generation on stop_sequence #364

Merged · 13 commits · Jul 12, 2023
68 changes: 48 additions & 20 deletions binaries/llm-cli/src/main.rs
@@ -7,7 +7,10 @@ use std::{
use clap::Parser;
use cli_args::Args;
use color_eyre::eyre::{bail, Context, ContextCompat, Result};
use llm::{InferenceError, InferenceFeedback, InferenceResponse, InferenceSession};
use llm::{
conversation_inference_callback, InferenceError, InferenceFeedback, InferenceResponse,
InferenceSession,
};
use rustyline::{
error::ReadlineError,
history::DefaultHistory,
@@ -256,6 +259,17 @@ fn interactive(
let parameters = generate.inference_parameters(model.eot_token_id());
let mut rng = generate.rng();

let stop_sequence = message_prompt_template
@averypelle (Contributor, Author) commented on Jul 9, 2023:

I am assuming the message_prompt_template is something like:

User: {{PROMPT}}

But I'm very open to suggestions here on how to make this more robust.

.map(|s| s.replace("{{PROMPT}}", "").trim().to_owned())
.unwrap_or_default();
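
As an illustrative sketch (not part of the diff): assuming a message prompt template of "User: {{PROMPT}}", the derivation above strips the placeholder and trims whitespace, leaving "User:" as the stop sequence.

// Hypothetical template value, chosen only to illustrate the derivation above.
let message_prompt_template = Some("User: {{PROMPT}}".to_string());
let stop_sequence = message_prompt_template
    .map(|s| s.replace("{{PROMPT}}", "").trim().to_owned())
    .unwrap_or_default();
assert_eq!(stop_sequence, "User:");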

let mut buf = String::new();

fn print_token(t: String) {
print!("{t}");
std::io::stdout().flush().unwrap();
}

fn session_ends_with_newline(session: &InferenceSession) -> bool {
session
.decoded_tokens()
@@ -283,26 +297,40 @@
};
sp.clear();

session.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
prompt: "".into(),
parameters: &parameters,
play_back_previous_tokens: false,
maximum_token_count: generate.num_predict,
},
&mut Default::default(),
|r| match r {
InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
print!("{t}");
std::io::stdout().flush().unwrap();
if chat_mode {
session.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
prompt: "".into(),
parameters: &parameters,
play_back_previous_tokens: false,
maximum_token_count: generate.num_predict,
},
&mut Default::default(),
conversation_inference_callback(stop_sequence.clone(), &mut buf, print_token),
)
} else {
session.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
prompt: "".into(),
parameters: &parameters,
play_back_previous_tokens: false,
maximum_token_count: generate.num_predict,
},
&mut Default::default(),
|r| match r {
InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
print_token(t);

Ok(InferenceFeedback::Continue)
}
_ => Ok(InferenceFeedback::Continue),
},
)
Ok(InferenceFeedback::Continue)
}
_ => Ok(InferenceFeedback::Continue),
},
)
}
};

let mut rl = rustyline::Editor::<LineContinuationValidator, DefaultHistory>::new()?;
33 changes: 33 additions & 0 deletions crates/llm-base/src/inference_session.rs
@@ -853,3 +853,36 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
None => Ok(InferenceFeedback::Continue),
}
}

/// Callback to be passed to [InferenceSession::infer] that will print the
/// token to stdout and will halt execution when the stop sequence is encountered.
/// Only to be used for chat mode.
pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
stop_sequence: String,
buf: &'a mut String,
mut print_token: impl FnMut(String) + 'a,
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
move |resp| match resp {
InferenceResponse::InferredToken(t) => {
let mut reverse_buf = buf.clone();
reverse_buf.push_str(t.as_str());
if stop_sequence.as_str().eq(reverse_buf.as_str()) {
buf.clear();
return Ok(InferenceFeedback::Halt);
} else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) {
buf.push_str(t.as_str());
return Ok(InferenceFeedback::Continue);
}

if buf.is_empty() {
print_token(t);
Ok(InferenceFeedback::Continue)
} else {
print_token(reverse_buf);
Ok(InferenceFeedback::Continue)
}
}
InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
_ => Ok(InferenceFeedback::Continue),
}
}
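
For reference, a minimal usage sketch of the new helper (the `session`, `model`, `rng`, and `parameters` bindings and the "User:" stop sequence are assumptions for illustration, mirroring the CLI and vicuna-chat changes in this PR):

// Sketch only: assumes the conversation so far has already been fed to `session`,
// so inference starts from an empty prompt, as in the CLI change above.
let mut buf = String::new();
session
    .infer::<std::convert::Infallible>(
        model.as_ref(),
        &mut rng,
        &llm::InferenceRequest {
            prompt: "".into(),
            parameters: &parameters,
            play_back_previous_tokens: false,
            maximum_token_count: None,
        },
        &mut Default::default(),
        llm::conversation_inference_callback("User:".to_string(), &mut buf, |t| {
            // Print each generated token as it is produced.
            print!("{t}");
        }),
    )
    .unwrap_or_else(|e| panic!("{e}"));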
7 changes: 4 additions & 3 deletions crates/llm-base/src/lib.rs
@@ -23,9 +23,10 @@ pub use ggml;
pub use ggml::Type as ElementType;

pub use inference_session::{
feed_prompt_callback, GraphOutputs, InferenceError, InferenceFeedback, InferenceRequest,
InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot,
InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, SnapshotError,
conversation_inference_callback, feed_prompt_callback, GraphOutputs, InferenceError,
InferenceFeedback, InferenceRequest, InferenceResponse, InferenceSession,
InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats,
ModelKVMemoryType, SnapshotError,
};
pub use loader::{
load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic,
46 changes: 13 additions & 33 deletions crates/llm/examples/vicuna-chat.rs
@@ -1,4 +1,5 @@
use clap::Parser;
use llm_base::conversation_inference_callback;
use rustyline::error::ReadlineError;
use std::{convert::Infallible, io::Write, path::PathBuf};

@@ -62,7 +63,11 @@ fn main() {
&mut Default::default(),
llm::feed_prompt_callback(|resp| match resp {
llm::InferenceResponse::PromptToken(t)
| llm::InferenceResponse::InferredToken(t) => print_token(t),
| llm::InferenceResponse::InferredToken(t) => {
print_token(t);

Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Continue)
}
_ => Ok(llm::InferenceFeedback::Continue),
}),
)
@@ -81,7 +86,7 @@ fn main() {
match readline {
Ok(line) => {
let stats = session
.infer(
.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
@@ -93,7 +98,11 @@
maximum_token_count: None,
},
&mut Default::default(),
inference_callback(String::from(user_name), &mut buf),
conversation_inference_callback(
String::from(user_name),
&mut buf,
print_token,
),
)
.unwrap_or_else(|e| panic!("{e}"));

@@ -116,36 +125,7 @@ fn main() {
println!("\n\nInference stats:\n{res}");
}

fn inference_callback(
stop_sequence: String,
buf: &mut String,
) -> impl FnMut(llm::InferenceResponse) -> Result<llm::InferenceFeedback, Infallible> + '_ {
move |resp| match resp {
llm::InferenceResponse::InferredToken(t) => {
let mut reverse_buf = buf.clone();
reverse_buf.push_str(t.as_str());
if stop_sequence.as_str().eq(reverse_buf.as_str()) {
buf.clear();
return Ok(llm::InferenceFeedback::Halt);
} else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) {
buf.push_str(t.as_str());
return Ok(llm::InferenceFeedback::Continue);
}

if buf.is_empty() {
print_token(t)
} else {
print_token(reverse_buf)
}
}
llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
_ => Ok(llm::InferenceFeedback::Continue),
}
}

fn print_token(t: String) -> Result<llm::InferenceFeedback, Infallible> {
fn print_token(t: String) {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(llm::InferenceFeedback::Continue)
}
16 changes: 8 additions & 8 deletions crates/llm/src/lib.rs
@@ -77,14 +77,14 @@ use std::{
// Try not to expose too many GGML details here.
// This is the "user-facing" API, and GGML may not always be our backend.
pub use llm_base::{
feed_prompt_callback, ggml::format as ggml_format, load, load_progress_callback_stdout,
quantize, samplers, ElementType, FileType, FileTypeFormat, FormatMagic, Hyperparameters,
InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse,
InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef,
InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model,
ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress,
Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer,
TokenizerSource,
conversation_inference_callback, feed_prompt_callback, ggml::format as ggml_format, load,
load_progress_callback_stdout, quantize, samplers, ElementType, FileType, FileTypeFormat,
FormatMagic, Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters,
InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig,
InferenceSnapshot, InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel,
LoadError, LoadProgress, Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest,
Prompt, QuantizeError, QuantizeProgress, Sampler, SnapshotError, TokenBias, TokenId,
TokenUtf8Buffer, TokenizationError, Tokenizer, TokenizerSource,
};

use serde::Serialize;