Merge pull request #364 from averypelle/fix/chat-halting
allow chat to halt new token generation on `stop_sequence`
philpax authored Jul 12, 2023
2 parents bc9f2fe + 34a8c68 commit fc1c052
Showing 13 changed files with 383 additions and 307 deletions.
70 changes: 35 additions & 35 deletions binaries/llm-cli/src/cli_args.rs
@@ -179,51 +179,51 @@ pub struct Chat {
     #[arg(long, short = 'f')]
     pub prelude_prompt_file: PathBuf,
 
-    /// The per-message prompt to use.
+    /// The per-message prefix to be prepended to the user's message.
     ///
-    /// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the
-    /// user's message.
+    /// The `{{PROMPT}}` will automatically be appended to this prefix.
     #[arg(long, short = 'p')]
-    pub message_prompt: Option<String>,
+    pub message_prompt_prefix: Option<String>,
 
-    /// The file to read the per-message prompt from.
+    /// The file containing the per-message prefix to be prepended to the user's message.
     ///
-    /// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the
-    /// user's message.
+    /// The `{{PROMPT}}` will automatically be appended to this prefix.
     #[arg(long, short = 'q')]
-    pub message_prompt_file: Option<PathBuf>,
+    pub message_prompt_prefix_file: Option<PathBuf>,
 
     #[command(flatten)]
     pub generate: Generate,
 }
 impl Chat {
-    pub fn message_prompt(&self) -> eyre::Result<String> {
-        if self.message_prompt.is_some() && self.message_prompt_file.is_some() {
-            eyre::bail!("Cannot specify both --message-prompt and --message-prompt-file")
-        }
-
-        if let Some(message_prompt_file) = &self.message_prompt_file {
-            read_prompt_file(message_prompt_file).and_then(|prompt| {
-                prompt
-                    .contains("{{PROMPT}}")
-                    .then_some(prompt)
-                    .ok_or_else(|| {
-                        eyre::eyre!(
-                            "Message prompt file must contain a `{{{{PROMPT}}}}` placeholder, but it does not"
-                        )
-                    })
-            })
-        } else if let Some(message_prompt) = &self.message_prompt {
-            message_prompt
-                .contains("{{PROMPT}}")
-                .then(|| message_prompt.clone())
-                .ok_or_else(|| {
-                    eyre::eyre!(
-                        "Message prompt must contain a `{{{{PROMPT}}}}` placeholder, but it does not"
-                    )
-                })
-        } else {
-            eyre::bail!("Must specify either --message-prompt or --message-prompt-file")
+    pub fn message_prompt_prefix(&self) -> eyre::Result<String> {
+        const MESSAGE_PROMPT_PREFIX_ERROR: &str = concat!(
+            "Message prompt prefix must not contain a `{{PROMPT}}` placeholder. ",
+            "The prompt will be automatically appended to the prefix."
+        );
+
+        match (
+            &self.message_prompt_prefix,
+            &self.message_prompt_prefix_file,
+        ) {
+            (None, None) => eyre::bail!(
+                "Must specify either --message-prompt-prefix or --message-prompt-prefix-file"
+            ),
+            (Some(_), Some(_)) => eyre::bail!(
+                "Cannot specify both --message-prompt-prefix and --message-prompt-prefix-file"
+            ),
+            (Some(message_prompt_prefix), None) => {
+                if message_prompt_prefix.contains("{{PROMPT}}") {
+                    eyre::bail!("{MESSAGE_PROMPT_PREFIX_ERROR}");
+                }
+                Ok(message_prompt_prefix.clone())
+            }
+            (None, Some(message_prompt_prefix_file)) => {
+                let prompt = read_prompt_file(message_prompt_prefix_file)?;
+                if prompt.contains("{{PROMPT}}") {
+                    eyre::bail!("{MESSAGE_PROMPT_PREFIX_ERROR}");
+                }
+                Ok(prompt)
+            }
+        }
     }
 }
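
To make the behavior change concrete, here is a minimal, illustrative sketch of how the new prefix composes with a user message. The prefix value is a made-up example, not a default; the assembly mirrors the `chat` function in interactive.rs below.

```rust
// Illustrative sketch only: the user's line is appended directly to the
// per-message prefix, so no `{{PROMPT}}` placeholder is needed anymore.
fn main() {
    let message_prompt_prefix = "User: "; // hypothetical example value
    let line = "How do llamas sleep?";
    let mut prompt = format!("{message_prompt_prefix}{line}");
    // Terminate the turn with a newline, as `chat` does.
    if !prompt.ends_with('\n') {
        prompt.push('\n');
    }
    assert_eq!(prompt, "User: How do llamas sleep?\n");
}
```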
223 changes: 223 additions & 0 deletions binaries/llm-cli/src/interactive.rs
@@ -0,0 +1,223 @@
use std::convert::Infallible;

use color_eyre::eyre;
use rustyline::{
    error::ReadlineError,
    history::DefaultHistory,
    validate::{ValidationContext, ValidationResult, Validator},
    Cmd, Completer, Helper, Highlighter, Hinter, KeyCode, KeyEvent, Modifiers,
};

use crate::{
    cli_args::{Chat, Repl},
    snapshot, util,
};

pub fn repl(
    Repl {
        generate,
        model_load,
        prompt_file,
    }: &Repl,
) -> eyre::Result<()> {
    let (inference_session_config, parameters, model, mut rng) =
        initialize_common_state(generate, model_load)?;

    let template = prompt_file.contents()?;

    let model = model.as_ref();
    let mut session = create_session(model, inference_session_config);
    readline_loop(|raw_line| {
        let line = raw_line.replace("\\\n", "\n");

        let prompt = template
            .as_deref()
            .map(|template| util::process_prompt(template, &line))
            .unwrap_or(line);
        feed_prompt_with_spinner(model, &mut session, &parameters, prompt)?;

        session.infer::<Infallible>(
            model,
            &mut rng,
            &llm::InferenceRequest {
                // The prompt was already fed above, so infer from an empty prompt.
                prompt: "".into(),
                parameters: &parameters,
                play_back_previous_tokens: false,
                maximum_token_count: generate.num_predict,
            },
            &mut Default::default(),
            |r| {
                if let llm::InferenceResponse::InferredToken(t) = r {
                    util::print_token(t);
                }
                Ok(llm::InferenceFeedback::Continue)
            },
        )?;

        if !session_ends_with_newline(&session) {
            println!();
        }
        // Start a fresh session for the next line; the REPL keeps no state between prompts.
        session = create_session(model, inference_session_config);

        Ok(())
    })
}
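
`util::process_prompt` is defined elsewhere in the crate and is not part of this diff. A plausible minimal equivalent, assumed from the `{{PROMPT}}` placeholder convention rather than taken from the crate's source, is a straight substitution:

```rust
// Hypothetical stand-in for `util::process_prompt`: substitute the user's
// line into the REPL template at the `{{PROMPT}}` placeholder.
fn process_prompt(template: &str, line: &str) -> String {
    template.replace("{{PROMPT}}", line)
}
```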

pub fn chat(args: &Chat) -> eyre::Result<()> {
    let Chat {
        model_load,
        prelude_prompt_file,
        generate,
        ..
    } = args;

    let (inference_session_config, parameters, model, mut rng) =
        initialize_common_state(generate, model_load)?;

    let prelude_prompt = std::fs::read_to_string(prelude_prompt_file)?;
    let message_prompt_prefix = args.message_prompt_prefix()?;

    let model = model.as_ref();
    let mut session = create_session(model, inference_session_config);
    feed_prompt_with_spinner(model, &mut session, &parameters, prelude_prompt)?;

    readline_loop(|raw_line| {
        let prompt = {
            let line = raw_line.replace("\\\n", "\n");
            let mut prompt = format!("{message_prompt_prefix}{line}");
            // Add a newline to the end of the prompt if it doesn't end with one
            if !prompt.ends_with('\n') {
                prompt.push('\n');
            }
            prompt
        };

        session.infer::<Infallible>(
            model,
            &mut rng,
            &llm::InferenceRequest {
                prompt: (&prompt).into(),
                parameters: &parameters,
                play_back_previous_tokens: false,
                maximum_token_count: generate.num_predict,
            },
            &mut Default::default(),
            // The message prompt prefix doubles as the stop sequence: generation
            // halts as soon as the model starts to emit it again.
            llm::conversation_inference_callback(&message_prompt_prefix, util::print_token),
        )?;

        if !session_ends_with_newline(&session) {
            println!();
        }

        Ok(())
    })
}
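
`llm::conversation_inference_callback` is where the halting described in the commit title happens; its body is not part of this diff. The sketch below is an assumption about how such a callback can work, inferred from this call site: buffer inferred tokens while they are still a prefix of the stop sequence (here, the message prompt prefix), and halt once the model starts a new user turn.

```rust
// Sketch (not the crate's exact code): halt generation when the model
// begins to emit `stop_sequence`; otherwise forward tokens to `print_token`.
fn conversation_inference_callback<'a, E: std::error::Error + 'static>(
    stop_sequence: &'a str,
    mut print_token: impl FnMut(String) + 'a,
) -> impl FnMut(llm::InferenceResponse) -> Result<llm::InferenceFeedback, E> + 'a {
    let mut buffer = String::new();
    move |response| match response {
        llm::InferenceResponse::InferredToken(token) => {
            buffer.push_str(&token);
            if buffer == stop_sequence {
                // The model is starting a new user turn: stop generating.
                Ok(llm::InferenceFeedback::Halt)
            } else if stop_sequence.starts_with(buffer.as_str()) {
                // Still a possible prefix of the stop sequence: hold the
                // tokens back rather than printing them.
                Ok(llm::InferenceFeedback::Continue)
            } else {
                // False alarm: flush everything buffered and keep going.
                print_token(std::mem::take(&mut buffer));
                Ok(llm::InferenceFeedback::Continue)
            }
        }
        llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
        _ => Ok(llm::InferenceFeedback::Continue),
    }
}
```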

fn initialize_common_state(
    generate: &crate::cli_args::Generate,
    model_load: &crate::cli_args::ModelLoad,
) -> eyre::Result<(
    llm::InferenceSessionConfig,
    llm::InferenceParameters,
    Box<dyn llm::Model>,
    rand::rngs::StdRng,
)> {
    let model = model_load.load(generate.use_gpu)?;
    Ok((
        generate.inference_session_config(),
        generate.inference_parameters(model.eot_token_id()),
        model,
        generate.rng(),
    ))
}

fn feed_prompt_with_spinner(
    model: &dyn llm::Model,
    session: &mut llm::InferenceSession,
    parameters: &llm::InferenceParameters,
    mut prompt: String,
) -> eyre::Result<()> {
    // Add a newline to the beginning of the prompt if the last character in the session is not a newline
    if !session_ends_with_newline(session) {
        prompt.insert(0, '\n');
    }

    let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None);
    let result = session.feed_prompt(
        model,
        parameters,
        &prompt,
        // OutputRequest
        &mut Default::default(),
        |_| Ok::<_, Infallible>(llm::InferenceFeedback::Continue),
    );
    sp.clear();

    Ok(result?)
}

fn create_session(
    model: &dyn llm::Model,
    inference_session_config: llm::InferenceSessionConfig,
) -> llm::InferenceSession {
    snapshot::read_or_create_session(model, None, None, inference_session_config).0
}

fn session_ends_with_newline(session: &llm::InferenceSession) -> bool {
    session
        .decoded_tokens()
        .last()
        .map(|t| *t == b'\n')
        .unwrap_or(true)
}

fn readline_loop(mut body: impl FnMut(String) -> eyre::Result<()>) -> eyre::Result<()> {
    let mut rl = rustyline::Editor::<LineContinuationValidator, DefaultHistory>::new()?;
    rl.set_helper(Some(LineContinuationValidator));
    rl.bind_sequence(force_newline_event_seq(), Cmd::Newline);

    loop {
        match rl.readline(">> ") {
            Ok(raw_line) => {
                if let Err(err) = body(raw_line) {
                    log::error!("{err}");
                    break;
                }
            }
            Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => {
                break;
            }
            Err(err) => {
                log::error!("{err}");
                break;
            }
        }
    }

    Ok(())
}

#[cfg(not(windows))]
fn force_newline_event_seq() -> KeyEvent {
    KeyEvent(KeyCode::Enter, Modifiers::ALT)
}

// On Windows, `SHIFT+ENTER` is the key sequence for forcing a newline. This is
// because `ALT+ENTER` typically maximizes the window.
#[cfg(windows)]
fn force_newline_event_seq() -> KeyEvent {
    KeyEvent(KeyCode::Enter, Modifiers::SHIFT)
}

#[derive(Completer, Helper, Highlighter, Hinter, Debug, Clone, Copy)]
struct LineContinuationValidator;

impl Validator for LineContinuationValidator {
    fn validate(&self, ctx: &mut ValidationContext) -> rustyline::Result<ValidationResult> {
        if ctx.input().ends_with('\\') {
            Ok(ValidationResult::Incomplete)
        } else {
            Ok(ValidationResult::Valid(None))
        }
    }
}
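
Together, `LineContinuationValidator` and the backslash replacement in `repl` and `chat` implement multi-line input: rustyline keeps reading while the buffer ends with `\`, and the callers then fold each escaped line break into a plain newline. A small self-contained check of that round trip:

```rust
// The validator keeps the editor open while input ends with `\`, so a
// continued entry arrives as one string containing literal `\` + newline
// pairs; the REPL folds them into plain newlines before building the prompt.
fn main() {
    let raw_line = "first line \\\nsecond line";
    let line = raw_line.replace("\\\n", "\n");
    assert_eq!(line, "first line \nsecond line");
}
```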