This repository has been archived by the owner on Jun 24, 2024. It is now read-only.

allow chat to halt new token generation on stop_sequence #364

Merged (13 commits) on Jul 12, 2023
70 changes: 35 additions & 35 deletions binaries/llm-cli/src/cli_args.rs
@@ -179,51 +179,51 @@ pub struct Chat {
#[arg(long, short = 'f')]
pub prelude_prompt_file: PathBuf,

- /// The per-message prompt to use.
+ /// The per-message prefix to be prepended to the user's message.
///
- /// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the
- /// user's message.
+ /// The `{{PROMPT}}` will automatically be appended to this prefix.
#[arg(long, short = 'p')]
- pub message_prompt: Option<String>,
+ pub message_prompt_prefix: Option<String>,

- /// The file to read the per-message prompt from.
+ /// The file containing the per-message prefix to be prepended to the user's message.
///
- /// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the
- /// user's message.
+ /// The `{{PROMPT}}` will automatically be appended to this prefix.
#[arg(long, short = 'q')]
- pub message_prompt_file: Option<PathBuf>,
+ pub message_prompt_prefix_file: Option<PathBuf>,

#[command(flatten)]
pub generate: Generate,
}
impl Chat {
- pub fn message_prompt(&self) -> eyre::Result<String> {
- if self.message_prompt.is_some() && self.message_prompt_file.is_some() {
- eyre::bail!("Cannot specify both --message-prompt and --message-prompt-file")
- }
-
- if let Some(message_prompt_file) = &self.message_prompt_file {
- read_prompt_file(message_prompt_file).and_then(|prompt| {
- prompt
- .contains("{{PROMPT}}")
- .then_some(prompt)
- .ok_or_else(|| {
- eyre::eyre!(
- "Message prompt file must contain a `{{{{PROMPT}}}}` placeholder, but it does not"
- )
- })
- })
- } else if let Some(message_prompt) = &self.message_prompt {
- message_prompt
- .contains("{{PROMPT}}")
- .then(|| message_prompt.clone())
- .ok_or_else(|| {
- eyre::eyre!(
- "Message prompt must contain a `{{{{PROMPT}}}}` placeholder, but it does not"
- )
- })
- } else {
- eyre::bail!("Must specify either --message-prompt or --message-prompt-file")
+ pub fn message_prompt_prefix(&self) -> eyre::Result<String> {
+ const MESSAGE_PROMPT_PREFIX_ERROR: &str = concat!(
+ "Message prompt prefix must not contain a `{{PROMPT}}` placeholder. ",
+ "The prompt will be automatically appended to the prefix."
+ );
+
+ match (
+ &self.message_prompt_prefix,
+ &self.message_prompt_prefix_file,
+ ) {
+ (None, None) => eyre::bail!(
+ "Must specify either --message-prompt-prefix or --message-prompt-prefix-file"
+ ),
+ (Some(_), Some(_)) => eyre::bail!(
+ "Cannot specify both --message-prompt-prefix and --message-prompt-prefix-file"
+ ),
+ (Some(message_prompt_prefix), None) => {
+ if message_prompt_prefix.contains("{{PROMPT}}") {
+ eyre::bail!("{MESSAGE_PROMPT_PREFIX_ERROR}");
+ }
+ Ok(message_prompt_prefix.clone())
+ }
+ (None, Some(message_prompt_prefix_file)) => {
[Review comment from Contributor Author]: wow very cool!

+ let prompt = read_prompt_file(message_prompt_prefix_file)?;
+ if prompt.contains("{{PROMPT}}") {
+ eyre::bail!("{MESSAGE_PROMPT_PREFIX_ERROR}");
+ }
+ Ok(prompt)
}
}
}
}
83 changes: 55 additions & 28 deletions binaries/llm-cli/src/main.rs
@@ -7,7 +7,10 @@ use std::{
use clap::Parser;
use cli_args::Args;
use color_eyre::eyre::{bail, Context, ContextCompat, Result};
- use llm::{InferenceError, InferenceFeedback, InferenceResponse, InferenceSession};
+ use llm::{
+ conversation_inference_callback, InferenceError, InferenceFeedback, InferenceResponse,
+ InferenceSession,
+ };
use rustyline::{
error::ReadlineError,
history::DefaultHistory,
@@ -235,7 +238,7 @@ fn chat(args: &cli_args::Chat) -> Result<()> {
&args.model_load,
true,
Some(std::fs::read_to_string(&args.prelude_prompt_file)?.as_str()),
- Some(&args.message_prompt()?),
+ Some(&args.message_prompt_prefix()?),
)
}

@@ -244,7 +247,7 @@ fn interactive(
model_load: &cli_args::ModelLoad,
chat_mode: bool,
mut initial_prompt_template: Option<&str>,
- message_prompt_template: Option<&str>,
+ message_prompt_prefix: Option<&str>,
) -> Result<()> {
let inference_session_config = generate.inference_session_config();
let model = model_load.load(generate.use_gpu)?;
@@ -261,7 +264,7 @@ fn interactive(
.decoded_tokens()
.last()
.map(|t| *t == b'\n')
- .unwrap_or(false)
+ .unwrap_or(true)
}

let mut infer = |session: &mut InferenceSession, mut prompt: String| {
@@ -283,26 +286,47 @@
};
sp.clear();

- session.infer::<Infallible>(
- model.as_ref(),
- &mut rng,
- &llm::InferenceRequest {
- prompt: "".into(),
- parameters: &parameters,
- play_back_previous_tokens: false,
- maximum_token_count: generate.num_predict,
- },
- &mut Default::default(),
- |r| match r {
- InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
- print!("{t}");
- std::io::stdout().flush().unwrap();
+ fn print_token(t: String) {
+ print!("{t}");
+ std::io::stdout().flush().unwrap();
+ }

- Ok(InferenceFeedback::Continue)
- }
- _ => Ok(InferenceFeedback::Continue),
- },
- )
+ if chat_mode {
+ let stop_sequence = message_prompt_prefix.unwrap_or_default().to_owned();
+
+ session.infer::<Infallible>(
+ model.as_ref(),
+ &mut rng,
+ &llm::InferenceRequest {
+ prompt: "".into(),
+ parameters: &parameters,
+ play_back_previous_tokens: false,
+ maximum_token_count: generate.num_predict,
+ },
+ &mut Default::default(),
+ conversation_inference_callback(stop_sequence, print_token),
+ )
+ } else {
+ session.infer::<Infallible>(
+ model.as_ref(),
+ &mut rng,
+ &llm::InferenceRequest {
+ prompt: "".into(),
+ parameters: &parameters,
+ play_back_previous_tokens: false,
+ maximum_token_count: generate.num_predict,
+ },
+ &mut Default::default(),
+ |r| match r {
+ InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
+ print_token(t);
+
+ Ok(InferenceFeedback::Continue)
+ }
+ _ => Ok(InferenceFeedback::Continue),
+ },
+ )
+ }
};

let mut rl = rustyline::Editor::<LineContinuationValidator, DefaultHistory>::new()?;
@@ -316,12 +340,15 @@ fn interactive(
let line = raw_line.replace("\\\n", "\n");

// Use the initial prompt template for the first inference,
- // and then switch to the message prompt template afterwards
+ // and then switch to the message prompt prefix afterwards.
let mut prompt = initial_prompt_template
.take()
- .or(message_prompt_template)
- .map(|pf| process_prompt(pf, &line))
- .unwrap_or(line);
+ .map(|template| process_prompt(template, &line))
+ .unwrap_or_else(|| {
+ message_prompt_prefix
+ .map(|prefix| format!("{}{}", prefix, line))
+ .unwrap_or_else(|| line)
+ });

// Add a newline to the end of the prompt if it doesn't end with one in chat mode
if chat_mode && !prompt.ends_with('\n') {
@@ -338,7 +365,7 @@
}

// Reload session in REPL mode
- if message_prompt_template.is_none() {
+ if !chat_mode {
session = recreate_session();
}
}
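The prompt assembly above has two paths: the first message goes through {{PROMPT}} template substitution, and every later message simply gets the per-message prefix prepended. A standalone sketch of the two paths, assuming process_prompt substitutes the user's line for a {{PROMPT}} placeholder (the template and prefix strings are invented for illustration):

fn process_prompt(template: &str, line: &str) -> String {
    // Assumed behaviour of the CLI's process_prompt helper: substitute the
    // user's line for the {{PROMPT}} placeholder.
    template.replace("{{PROMPT}}", line)
}

fn main() {
    let line = "How are you?";

    // First message: the prelude/initial template still contains {{PROMPT}}.
    let first = process_prompt("A chat.\n### User: {{PROMPT}}\n### Assistant:", line);
    assert!(first.contains("### User: How are you?"));

    // Later messages: the per-message prefix is prepended verbatim.
    let prefix = "### User: ";
    let later = format!("{}{}", prefix, line);
    assert_eq!(later, "### User: How are you?");
}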
63 changes: 51 additions & 12 deletions crates/llm-base/src/inference_session.rs
@@ -280,7 +280,7 @@ impl InferenceSession {
}

/// Feed a prompt to the model for this session.
- pub fn feed_prompt<'a, E: std::error::Error + 'static, P: Into<Prompt<'a>>>(
+ pub fn feed_prompt<'a, E: std::error::Error + Send + Sync + 'static, P: Into<Prompt<'a>>>(
&mut self,
model: &dyn Model,
params: &InferenceParameters,
Expand Down Expand Up @@ -407,7 +407,7 @@ impl InferenceSession {
/// generated (specified by [InferenceRequest::maximum_token_count]).
///
/// This is a wrapper around [Self::feed_prompt] and [Self::infer_next_token].
- pub fn infer<E: std::error::Error + 'static>(
+ pub fn infer<E: std::error::Error + Send + Sync + 'static>(
&mut self,
model: &dyn Model,
rng: &mut impl rand::Rng,
Expand Down Expand Up @@ -438,14 +438,16 @@ impl InferenceSession {
let parameters = request.parameters;

// Feed the initial prompt through the transformer, to update its
- // context window with new data.
- self.feed_prompt(
- model,
- parameters,
- request.prompt,
- output_request,
- feed_prompt_callback(&mut callback),
- )?;
+ // context window with new data, if necessary.
+ if !request.prompt.is_empty() {
+ self.feed_prompt(
+ model,
+ parameters,
+ request.prompt,
+ output_request,
+ feed_prompt_callback(&mut callback),
+ )?;
+ }
stats.feed_prompt_duration = start_at.elapsed().unwrap();
stats.prompt_tokens = self.n_past;

Expand Down Expand Up @@ -661,7 +663,7 @@ pub enum InferenceError {
EndOfText,
#[error("the user-specified callback returned an error")]
/// The user-specified callback returned an error.
- UserCallback(Box<dyn std::error::Error>),
+ UserCallback(Box<dyn std::error::Error + Send + Sync>),
}

#[derive(Error, Debug)]
Expand Down Expand Up @@ -883,7 +885,7 @@ pub enum InferenceFeedback {

/// Adapt an [InferenceResponse] callback so that it can be used in a call to
/// [InferenceSession::feed_prompt].
- pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
+ pub fn feed_prompt_callback<'a, E: std::error::Error + Send + Sync + 'static>(
mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a,
) -> impl FnMut(&[u8]) -> Result<InferenceFeedback, E> + 'a {
let mut buffer = TokenUtf8Buffer::new();
@@ -892,3 +894,40 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
None => Ok(InferenceFeedback::Continue),
}
}

+ /// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
+ /// This callback is used in [InferenceSession::infer] in chat_mode.
+ pub fn conversation_inference_callback<'a, E: std::error::Error + Send + Sync + 'static>(
[Review comment from Contributor Author]: what do these do?

[Reply from Collaborator]: In Rust, objects get the Send trait if they can be sent across threads, and Sync if they can be used by multiple threads (see the Rust documentation on Send and Sync for more details). I needed to add this because eyre, which we use for error reporting in the CLI, expects the error from infer to be Send + Sync. The error is passed down from the callback to infer, so the trait requirements need to be updated across the library.

+ stop_sequence: &'a str,
+ mut callback: impl FnMut(String) + 'a,
+ ) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
+ let mut stop_sequence_buf = String::new();
[Review comment from Contributor Author]: interesting, scoping this buffer to the function seems a lot better!

+ move |resp| match resp {
+ InferenceResponse::InferredToken(token) => {
+ // We've generated a token, so we need to check if it's contained in the stop sequence.
+ let mut buf = stop_sequence_buf.clone();
+ buf.push_str(&token);
+
+ if buf.starts_with(stop_sequence) {
+ // We've generated the stop sequence, so we're done.
+ // Note that this will contain the extra tokens that were generated after the stop sequence,
+ // which may affect generation. This is non-ideal, but it's the best we can do without
+ // modifying the model.
+ stop_sequence_buf.clear();
+ return Ok(InferenceFeedback::Halt);
+ } else if stop_sequence.starts_with(&buf) {
+ // We've generated a prefix of the stop sequence, so we need to keep buffering.
+ stop_sequence_buf = buf;
+ return Ok(InferenceFeedback::Continue);
+ }
+
+ // We've generated a token that isn't part of the stop sequence, so we can
+ // pass it to the callback.
+ stop_sequence_buf.clear();
+ callback(buf);
+ Ok(InferenceFeedback::Continue)
+ }
+ InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
+ _ => Ok(InferenceFeedback::Continue),
+ }
+ }
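To make the buffering logic above concrete, here is a standalone sketch (plain strings, no llm types; the token stream and stop sequence are invented) of how tokens that form a prefix of the stop sequence are held back and only flushed once the partial match fails:

fn main() {
    let stop_sequence = "### User:";
    let mut buf = String::new();
    let mut output = String::new();

    // Hypothetical token stream: "##" and "# User:" together complete the stop
    // sequence, so neither reaches the output.
    for token in ["Hello", " there", "!", "##", "# User:", " ignored"] {
        buf.push_str(token);
        if buf.starts_with(stop_sequence) {
            break; // full stop sequence generated: halt inference
        } else if stop_sequence.starts_with(buf.as_str()) {
            continue; // partial match: keep buffering
        }
        output.push_str(&buf); // mismatch: flush everything buffered so far
        buf.clear();
    }

    assert_eq!(output, "Hello there!");
}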
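On the Send + Sync bounds discussed in the review thread above: eyre::Report, like anyhow::Error, can only be constructed from errors that are Send + Sync + 'static, so a generic error type that is propagated into an eyre::Result with the ? operator has to carry those bounds. A minimal sketch of that constraint, assuming eyre as a dependency (CallbackError and run_and_report are made-up names, not part of the PR):

use std::fmt;

// Hypothetical stand-in for whatever error a user callback might return.
#[derive(Debug)]
struct CallbackError(String);

impl fmt::Display for CallbackError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "callback failed: {}", self.0)
    }
}

impl std::error::Error for CallbackError {}

// Mirrors the shape of InferenceSession::infer: generic over the callback's
// error type. Dropping Send + Sync from the bound makes the `?` below fail to
// compile, because eyre::Report's From<E> conversion requires those bounds.
fn run_and_report<E: std::error::Error + Send + Sync + 'static>(
    mut callback: impl FnMut(&str) -> Result<(), E>,
) -> eyre::Result<()> {
    callback("token")?;
    Ok(())
}

fn main() -> eyre::Result<()> {
    run_and_report(|_token| Ok::<(), CallbackError>(()))
}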
7 changes: 4 additions & 3 deletions crates/llm-base/src/lib.rs
@@ -23,9 +23,10 @@ pub use ggml;
pub use ggml::Type as ElementType;

pub use inference_session::{
- feed_prompt_callback, GraphOutputs, InferenceError, InferenceFeedback, InferenceRequest,
- InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot,
- InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, RewindError, SnapshotError,
+ conversation_inference_callback, feed_prompt_callback, GraphOutputs, InferenceError,
+ InferenceFeedback, InferenceRequest, InferenceResponse, InferenceSession,
+ InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats,
+ ModelKVMemoryType, RewindError, SnapshotError,
};
pub use loader::{
load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic,
8 changes: 8 additions & 0 deletions crates/llm-base/src/tokenizer/mod.rs
@@ -230,6 +230,14 @@ impl Prompt<'_> {
}
})
}

+ /// Returns whether this prompt is empty.
+ pub fn is_empty(&self) -> bool {
+ match self {
+ Self::Text(text) => text.is_empty(),
+ Self::Tokens(tokens) => tokens.is_empty(),
+ }
+ }
}
impl<'a> Default for Prompt<'a> {
fn default() -> Self {