This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

allow chat to halt new token generation on stop_sequence #364

Merged · 13 commits · Jul 12, 2023
60 changes: 30 additions & 30 deletions binaries/llm-cli/src/cli_args.rs
@@ -179,51 +179,51 @@ pub struct Chat {
#[arg(long, short = 'f')]
pub prelude_prompt_file: PathBuf,

/// The per-message prompt to use.
/// The per-message prefix to be prepended to the user's message.
///
/// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the
/// user's message.
/// The `{{PROMPT}}` will automatically be appended to this prefix.
#[arg(long, short = 'p')]
pub message_prompt: Option<String>,
pub message_prompt_prefix: Option<String>,

/// The file to read the per-message prompt from.
/// The file containing the per-message prefix to be prepended to the user's message.
///
/// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the
/// user's message.
/// The `{{PROMPT}}` will automatically be appended to this prefix.
#[arg(long, short = 'q')]
pub message_prompt_file: Option<PathBuf>,
pub message_prompt_prefix_file: Option<PathBuf>,

#[command(flatten)]
pub generate: Generate,
}
impl Chat {
pub fn message_prompt(&self) -> eyre::Result<String> {
if self.message_prompt.is_some() && self.message_prompt_file.is_some() {
eyre::bail!("Cannot specify both --message-prompt and --message-prompt-file")
pub fn message_prompt_prefix(&self) -> eyre::Result<String> {
if self.message_prompt_prefix.is_some() && self.message_prompt_prefix_file.is_some() {
eyre::bail!(
"Cannot specify both --message-prompt-prefix and --message-prompt-prefix-file"
)
}

if let Some(message_prompt_file) = &self.message_prompt_file {
read_prompt_file(message_prompt_file).and_then(|prompt| {
prompt
.contains("{{PROMPT}}")
.then_some(prompt)
.ok_or_else(|| {
eyre::eyre!(
"Message prompt file must contain a `{{{{PROMPT}}}}` placeholder, but it does not"
)
})
if let Some(message_prompt_prefix_file) = &self.message_prompt_prefix_file {
read_prompt_file(message_prompt_prefix_file).and_then(|prompt| {
if prompt.contains("{{PROMPT}}") {
eyre::bail!(
"Message prompt file must not contain a `{{{{PROMPT}}}}` placeholder. The `{{PROMPT}}` will be automatically appended to the prefix."
)
} else {
Ok(prompt)
}
})
} else if let Some(message_prompt) = &self.message_prompt {
message_prompt
.contains("{{PROMPT}}")
.then(|| message_prompt.clone())
.ok_or_else(|| {
eyre::eyre!(
"Message prompt must contain a `{{{{PROMPT}}}}` placeholder, but it does not"
} else if let Some(message_prompt_prefix) = &self.message_prompt_prefix {
if message_prompt_prefix.contains("{{PROMPT}}") {
eyre::bail!(
"Message prompt file must not contain a `{{{{PROMPT}}}}` placeholder. The `{{PROMPT}}` will be automatically appended to the prefix."
)
})
} else {
Ok(message_prompt_prefix.clone())
}
} else {
eyre::bail!("Must specify either --message-prompt or --message-prompt-file")
eyre::bail!(
"Must specify either --message-prompt-prefix or --message-prompt-prefix-file"
)
}
}
}
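
To make the new prefix semantics concrete, here is a minimal sketch (not part of the diff) of how a per-message prompt is now assembled: the prefix must not carry a `{{PROMPT}}` placeholder, and the user's message is simply appended to it, mirroring the prompt construction in main.rs below. The helper name and the example strings are made up for illustration.

fn build_turn_prompt(prefix: &str, user_message: &str) -> String {
    // Under the old scheme the template had to contain `{{PROMPT}}`, which was
    // replaced with the user's message; under the new scheme the message is
    // appended directly to the prefix.
    format!("{}{}", prefix, user_message)
}

fn main() {
    let prompt = build_turn_prompt("User: ", "How do llamas sleep?");
    assert_eq!(prompt, "User: How do llamas sleep?");
}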
83 changes: 55 additions & 28 deletions binaries/llm-cli/src/main.rs
@@ -7,7 +7,10 @@ use std::{
use clap::Parser;
use cli_args::Args;
use color_eyre::eyre::{bail, Context, ContextCompat, Result};
use llm::{InferenceError, InferenceFeedback, InferenceResponse, InferenceSession};
use llm::{
conversation_inference_callback, InferenceError, InferenceFeedback, InferenceResponse,
InferenceSession,
};
use rustyline::{
error::ReadlineError,
history::DefaultHistory,
@@ -235,7 +238,7 @@ fn chat(args: &cli_args::Chat) -> Result<()> {
&args.model_load,
true,
Some(std::fs::read_to_string(&args.prelude_prompt_file)?.as_str()),
Some(&args.message_prompt()?),
Some(&args.message_prompt_prefix()?),
)
}

@@ -244,7 +247,7 @@ fn interactive(
model_load: &cli_args::ModelLoad,
chat_mode: bool,
mut initial_prompt_template: Option<&str>,
message_prompt_template: Option<&str>,
message_prompt_prefix: Option<&str>,
) -> Result<()> {
let inference_session_config = generate.inference_session_config();
let model = model_load.load(generate.use_gpu)?;
@@ -261,7 +264,7 @@ fn interactive(
.decoded_tokens()
.last()
.map(|t| *t == b'\n')
.unwrap_or(false)
.unwrap_or(true)
}

let mut infer = |session: &mut InferenceSession, mut prompt: String| {
@@ -283,26 +286,47 @@
};
sp.clear();

session.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
prompt: "".into(),
parameters: &parameters,
play_back_previous_tokens: false,
maximum_token_count: generate.num_predict,
},
&mut Default::default(),
|r| match r {
InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
print!("{t}");
std::io::stdout().flush().unwrap();
fn print_token(t: String) {
print!("{t}");
std::io::stdout().flush().unwrap();
}

Ok(InferenceFeedback::Continue)
}
_ => Ok(InferenceFeedback::Continue),
},
)
if chat_mode {
let stop_sequence = message_prompt_prefix.unwrap_or_default().to_owned();

session.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
prompt: "".into(),
parameters: &parameters,
play_back_previous_tokens: false,
maximum_token_count: generate.num_predict,
},
&mut Default::default(),
conversation_inference_callback(stop_sequence, print_token),
)
} else {
session.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
prompt: "".into(),
parameters: &parameters,
play_back_previous_tokens: false,
maximum_token_count: generate.num_predict,
},
&mut Default::default(),
|r| match r {
InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
print_token(t);

Ok(InferenceFeedback::Continue)
}
_ => Ok(InferenceFeedback::Continue),
},
)
}
};

let mut rl = rustyline::Editor::<LineContinuationValidator, DefaultHistory>::new()?;
@@ -316,12 +340,15 @@ fn interactive(
let line = raw_line.replace("\\\n", "\n");

// Use the initial prompt template for the first inference,
// and then switch to the message prompt template afterwards
// and then switch to the message prompt prefix afterwards.
let mut prompt = initial_prompt_template
.take()
.or(message_prompt_template)
.map(|pf| process_prompt(pf, &line))
.unwrap_or(line);
.map(|template| process_prompt(template, &line))
.unwrap_or_else(|| {
message_prompt_prefix
.map(|prefix| format!("{}{}", prefix, line))
.unwrap_or_else(|| line)
});

// Add a newline to the end of the prompt if it doesn't end with one in chat mode
if chat_mode && !prompt.ends_with('\n') {
@@ -338,7 +365,7 @@
}

// Reload session in REPL mode
if message_prompt_template.is_none() {
if !chat_mode {
session = recreate_session();
}
}
37 changes: 37 additions & 0 deletions crates/llm-base/src/inference_session.rs
@@ -892,3 +892,40 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
None => Ok(InferenceFeedback::Continue),
}
}

/// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
/// This callback is used in [InferenceSession::infer] in chat_mode.
pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
stop_sequence: String,
mut callback: impl FnMut(String) + 'a,
) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
let mut stop_sequence_buf = String::new();
    // Review comment from the PR author: interesting, scoping this buffer to the function seems a lot better!
move |resp| match resp {
InferenceResponse::InferredToken(token) => {
// We've generated a token, so we need to check if it's contained in the stop sequence.
let mut buf = stop_sequence_buf.clone();
buf.push_str(&token);

if buf.starts_with(&stop_sequence) {
// We've generated the stop sequence, so we're done.
// Note that this will contain the extra tokens that were generated after the stop sequence,
// which may affect generation. This is non-ideal, but it's the best we can do without
// modifying the model.
stop_sequence_buf.clear();
return Ok(InferenceFeedback::Halt);
} else if stop_sequence.starts_with(&buf) {
// We've generated a prefix of the stop sequence, so we need to keep buffering.
stop_sequence_buf = buf;
return Ok(InferenceFeedback::Continue);
}

// We've generated a token that isn't part of the stop sequence, so we can
// pass it to the callback.
stop_sequence_buf.clear();
callback(buf);
Ok(InferenceFeedback::Continue)
}
InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
_ => Ok(InferenceFeedback::Continue),
}
}
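
For orientation, here is a minimal sketch (not part of the diff) of how a caller might wire conversation_inference_callback into InferenceSession::infer, mirroring the chat-mode branch in main.rs above. The function name run_chat_turn, the "User:" stop sequence, and the exact return type are assumptions for illustration; the call shape and request fields are taken from the diff.

use std::{convert::Infallible, io::Write};

// Assumes `model`, `session`, `rng`, and `parameters` were set up as in the CLI
// code above; the return type mirrors how the examples consume `infer`'s result.
fn run_chat_turn(
    model: &dyn llm::Model,
    session: &mut llm::InferenceSession,
    rng: &mut impl rand::Rng,
    parameters: &llm::InferenceParameters,
) -> Result<llm::InferenceStats, llm::InferenceError> {
    session.infer::<Infallible>(
        model,
        rng,
        &llm::InferenceRequest {
            prompt: "".into(),
            parameters,
            play_back_previous_tokens: false,
            maximum_token_count: None,
        },
        &mut Default::default(),
        // Print tokens as they stream in, and halt as soon as the model starts
        // emitting the stop sequence (here, the next "User:" turn marker).
        llm::conversation_inference_callback("User:".to_string(), |t| {
            print!("{t}");
            std::io::stdout().flush().unwrap();
        }),
    )
}

With this helper in place, the CLI's chat mode and the vicuna-chat example below can share the same halting logic instead of each carrying a bespoke token buffer.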
7 changes: 4 additions & 3 deletions crates/llm-base/src/lib.rs
@@ -23,9 +23,10 @@ pub use ggml;
pub use ggml::Type as ElementType;

pub use inference_session::{
feed_prompt_callback, GraphOutputs, InferenceError, InferenceFeedback, InferenceRequest,
InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot,
InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, RewindError, SnapshotError,
conversation_inference_callback, feed_prompt_callback, GraphOutputs, InferenceError,
InferenceFeedback, InferenceRequest, InferenceResponse, InferenceSession,
InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats,
ModelKVMemoryType, RewindError, SnapshotError,
};
pub use loader::{
load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic,
43 changes: 9 additions & 34 deletions crates/llm/examples/vicuna-chat.rs
@@ -1,4 +1,5 @@
use clap::Parser;
use llm_base::conversation_inference_callback;
use rustyline::error::ReadlineError;
use std::{convert::Infallible, io::Write, path::PathBuf};

@@ -62,7 +63,11 @@ fn main() {
&mut Default::default(),
llm::feed_prompt_callback(|resp| match resp {
llm::InferenceResponse::PromptToken(t)
| llm::InferenceResponse::InferredToken(t) => print_token(t),
| llm::InferenceResponse::InferredToken(t) => {
print_token(t);

Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Continue)
}
_ => Ok(llm::InferenceFeedback::Continue),
}),
)
@@ -72,7 +77,6 @@ fn main() {

let mut rng = rand::thread_rng();
let mut res = llm::InferenceStats::default();
let mut buf = String::new();

loop {
println!();
@@ -81,7 +85,7 @@
match readline {
Ok(line) => {
let stats = session
.infer(
.infer::<Infallible>(
model.as_ref(),
&mut rng,
&llm::InferenceRequest {
@@ -93,7 +97,7 @@
maximum_token_count: None,
},
&mut Default::default(),
inference_callback(String::from(user_name), &mut buf),
conversation_inference_callback(format!("{character_name}:"), print_token),
)
.unwrap_or_else(|e| panic!("{e}"));

@@ -116,36 +120,7 @@
println!("\n\nInference stats:\n{res}");
}

fn inference_callback(
stop_sequence: String,
buf: &mut String,
) -> impl FnMut(llm::InferenceResponse) -> Result<llm::InferenceFeedback, Infallible> + '_ {
move |resp| match resp {
llm::InferenceResponse::InferredToken(t) => {
let mut reverse_buf = buf.clone();
reverse_buf.push_str(t.as_str());
if stop_sequence.as_str().eq(reverse_buf.as_str()) {
buf.clear();
return Ok(llm::InferenceFeedback::Halt);
} else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) {
buf.push_str(t.as_str());
return Ok(llm::InferenceFeedback::Continue);
}

if buf.is_empty() {
print_token(t)
} else {
print_token(reverse_buf)
}
}
llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
_ => Ok(llm::InferenceFeedback::Continue),
}
}

fn print_token(t: String) -> Result<llm::InferenceFeedback, Infallible> {
fn print_token(t: String) {
print!("{t}");
std::io::stdout().flush().unwrap();

Ok(llm::InferenceFeedback::Continue)
}
16 changes: 8 additions & 8 deletions crates/llm/src/lib.rs
@@ -77,14 +77,14 @@ use std::{
// Try not to expose too many GGML details here.
// This is the "user-facing" API, and GGML may not always be our backend.
pub use llm_base::{
feed_prompt_callback, ggml::format as ggml_format, load, load_progress_callback_stdout,
quantize, samplers, ElementType, FileType, FileTypeFormat, FormatMagic, Hyperparameters,
InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse,
InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef,
InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model,
ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress,
RewindError, Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError,
Tokenizer, TokenizerSource,
conversation_inference_callback, feed_prompt_callback, ggml::format as ggml_format, load,
load_progress_callback_stdout, quantize, samplers, ElementType, FileType, FileTypeFormat,
FormatMagic, Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters,
InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig,
InferenceSnapshot, InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel,
LoadError, LoadProgress, Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest,
Prompt, QuantizeError, QuantizeProgress, RewindError, Sampler, SnapshotError, TokenBias,
TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, TokenizerSource,
};

use serde::Serialize;