diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs
index 022957e6..ca6cf9e4 100644
--- a/binaries/llm-cli/src/cli_args.rs
+++ b/binaries/llm-cli/src/cli_args.rs
@@ -179,51 +179,51 @@ pub struct Chat {
     #[arg(long, short = 'f')]
     pub prelude_prompt_file: PathBuf,
 
-    /// The per-message prompt to use.
+    /// The per-message prefix to be prepended to the user's message.
     ///
-    /// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the
-    /// user's message.
+    /// The user's message will automatically be appended to this prefix.
     #[arg(long, short = 'p')]
-    pub message_prompt: Option<String>,
+    pub message_prompt_prefix: Option<String>,
 
-    /// The file to read the per-message prompt from.
+    /// The file containing the per-message prefix to be prepended to the user's message.
     ///
-    /// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the
-    /// user's message.
+    /// The user's message will automatically be appended to this prefix.
     #[arg(long, short = 'q')]
-    pub message_prompt_file: Option<PathBuf>,
+    pub message_prompt_prefix_file: Option<PathBuf>,
 
     #[command(flatten)]
     pub generate: Generate,
 }
 
 impl Chat {
-    pub fn message_prompt(&self) -> eyre::Result<String> {
-        if self.message_prompt.is_some() && self.message_prompt_file.is_some() {
-            eyre::bail!("Cannot specify both --message-prompt and --message-prompt-file")
-        }
-
-        if let Some(message_prompt_file) = &self.message_prompt_file {
-            read_prompt_file(message_prompt_file).and_then(|prompt| {
-                prompt
-                    .contains("{{PROMPT}}")
-                    .then_some(prompt)
-                    .ok_or_else(|| {
-                        eyre::eyre!(
-                            "Message prompt file must contain a `{{{{PROMPT}}}}` placeholder, but it does not"
-                        )
-                    })
-            })
-        } else if let Some(message_prompt) = &self.message_prompt {
-            message_prompt
-                .contains("{{PROMPT}}")
-                .then(|| message_prompt.clone())
-                .ok_or_else(|| {
-                    eyre::eyre!(
-                        "Message prompt must contain a `{{{{PROMPT}}}}` placeholder, but it does not"
-                    )
-                })
-        } else {
-            eyre::bail!("Must specify either --message-prompt or --message-prompt-file")
+    pub fn message_prompt_prefix(&self) -> eyre::Result<String> {
+        const MESSAGE_PROMPT_PREFIX_ERROR: &str = concat!(
+            "Message prompt prefix must not contain a `{{PROMPT}}` placeholder. ",
+            "The prompt will be automatically appended to the prefix."
+        );
+
+        match (
+            &self.message_prompt_prefix,
+            &self.message_prompt_prefix_file,
+        ) {
+            (None, None) => eyre::bail!(
+                "Must specify either --message-prompt-prefix or --message-prompt-prefix-file"
+            ),
+            (Some(_), Some(_)) => eyre::bail!(
+                "Cannot specify both --message-prompt-prefix and --message-prompt-prefix-file"
+            ),
+            (Some(message_prompt_prefix), None) => {
+                if message_prompt_prefix.contains("{{PROMPT}}") {
+                    eyre::bail!("{MESSAGE_PROMPT_PREFIX_ERROR}");
+                }
+                Ok(message_prompt_prefix.clone())
+            }
+            (None, Some(message_prompt_prefix_file)) => {
+                let prompt = read_prompt_file(message_prompt_prefix_file)?;
+                if prompt.contains("{{PROMPT}}") {
+                    eyre::bail!("{MESSAGE_PROMPT_PREFIX_ERROR}");
+                }
+                Ok(prompt)
+            }
         }
     }
 }
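[Note: not part of the patch] With the `{{PROMPT}}` placeholder gone, the value returned by `message_prompt_prefix()` is simply concatenated with the user's input on every turn. A minimal sketch of that per-turn assembly, mirroring the logic in `interactive::chat` below (the `build_turn` helper is illustrative, not something this patch adds):

    // Sketch only: `build_turn` is a hypothetical helper for illustration.
    fn build_turn(message_prompt_prefix: &str, line: &str) -> String {
        // Prepend the per-message prefix, then make sure the turn ends with a
        // newline, as `interactive::chat` does before handing it to the session.
        let mut prompt = format!("{message_prompt_prefix}{line}");
        if !prompt.ends_with('\n') {
            prompt.push('\n');
        }
        prompt
    }

    fn main() {
        assert_eq!(build_turn("User: ", "Hello!"), "User: Hello!\n");
    }
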
diff --git a/binaries/llm-cli/src/interactive.rs b/binaries/llm-cli/src/interactive.rs
new file mode 100644
index 00000000..8e5c71c4
--- /dev/null
+++ b/binaries/llm-cli/src/interactive.rs
@@ -0,0 +1,223 @@
+use std::convert::Infallible;
+
+use color_eyre::eyre;
+use rustyline::{
+    error::ReadlineError,
+    history::DefaultHistory,
+    validate::{ValidationContext, ValidationResult, Validator},
+    Cmd, Completer, Helper, Highlighter, Hinter, KeyCode, KeyEvent, Modifiers,
+};
+
+use crate::{
+    cli_args::{Chat, Repl},
+    snapshot, util,
+};
+
+pub fn repl(
+    Repl {
+        generate,
+        model_load,
+        prompt_file,
+    }: &Repl,
+) -> eyre::Result<()> {
+    let (inference_session_config, parameters, model, mut rng) =
+        initialize_common_state(generate, model_load)?;
+
+    let template = prompt_file.contents()?;
+
+    let model = model.as_ref();
+    let mut session = create_session(model, inference_session_config);
+    readline_loop(|raw_line| {
+        let line = raw_line.replace("\\\n", "\n");
+
+        let prompt = template
+            .as_deref()
+            .map(|template| util::process_prompt(template, &line))
+            .unwrap_or(line);
+        feed_prompt_with_spinner(model, &mut session, &parameters, prompt)?;
+
+        session.infer::<Infallible>(
+            model,
+            &mut rng,
+            &llm::InferenceRequest {
+                prompt: "".into(),
+                parameters: &parameters,
+                play_back_previous_tokens: false,
+                maximum_token_count: generate.num_predict,
+            },
+            &mut Default::default(),
+            |r| {
+                if let llm::InferenceResponse::InferredToken(t) = r {
+                    util::print_token(t);
+                }
+                Ok(llm::InferenceFeedback::Continue)
+            },
+        )?;
+
+        if !session_ends_with_newline(&session) {
+            println!();
+        }
+        session = create_session(model, inference_session_config);
+
+        Ok(())
+    })
+}
+
+pub fn chat(args: &Chat) -> eyre::Result<()> {
+    let Chat {
+        model_load,
+        prelude_prompt_file,
+        generate,
+        ..
+    } = args;
+
+    let (inference_session_config, parameters, model, mut rng) =
+        initialize_common_state(generate, model_load)?;
+
+    let prelude_prompt = std::fs::read_to_string(prelude_prompt_file)?;
+    let message_prompt_prefix = args.message_prompt_prefix()?;
+
+    let model = model.as_ref();
+    let mut session = create_session(model, inference_session_config);
+    feed_prompt_with_spinner(model, &mut session, &parameters, prelude_prompt)?;
+
+    readline_loop(|raw_line| {
+        let prompt = {
+            let line = raw_line.replace("\\\n", "\n");
+            let mut prompt = format!("{message_prompt_prefix}{line}");
+            // Add a newline to the end of the prompt if it doesn't end with one
+            if !prompt.ends_with('\n') {
+                prompt.push('\n');
+            }
+            prompt
+        };
+
+        session.infer::<Infallible>(
+            model,
+            &mut rng,
+            &llm::InferenceRequest {
+                prompt: (&prompt).into(),
+                parameters: &parameters,
+                play_back_previous_tokens: false,
+                maximum_token_count: generate.num_predict,
+            },
+            &mut Default::default(),
+            llm::conversation_inference_callback(&message_prompt_prefix, util::print_token),
+        )?;
+
+        if !session_ends_with_newline(&session) {
+            println!();
+        }
+
+        Ok(())
+    })
+}
+
+fn initialize_common_state(
+    generate: &crate::cli_args::Generate,
+    model_load: &crate::cli_args::ModelLoad,
+) -> eyre::Result<(
+    llm::InferenceSessionConfig,
+    llm::InferenceParameters,
+    Box<dyn llm::Model>,
+    rand::rngs::StdRng,
+)> {
+    let model = model_load.load(generate.use_gpu)?;
+    Ok((
+        generate.inference_session_config(),
+        generate.inference_parameters(model.eot_token_id()),
+        model,
+        generate.rng(),
+    ))
+}
+
+fn feed_prompt_with_spinner(
+    model: &dyn llm::Model,
+    session: &mut llm::InferenceSession,
+    parameters: &llm::InferenceParameters,
+    mut prompt: String,
+) -> eyre::Result<()> {
+    // Add a newline to the beginning of the prompt if the last character in the session is not a newline
+    if !session_ends_with_newline(session) {
+        prompt.insert(0, '\n');
+    }
+
+    let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None);
+    let result = session.feed_prompt(
+        model,
+        parameters,
+        &prompt,
+        // OutputRequest
+        &mut Default::default(),
+        |_| Ok::<_, Infallible>(llm::InferenceFeedback::Continue),
+    );
+    sp.clear();
+
+    Ok(result?)
+}
+
+fn create_session(
+    model: &dyn llm::Model,
+    inference_session_config: llm::InferenceSessionConfig,
+) -> llm::InferenceSession {
+    snapshot::read_or_create_session(model, None, None, inference_session_config).0
+}
+
+fn session_ends_with_newline(session: &llm::InferenceSession) -> bool {
+    session
+        .decoded_tokens()
+        .last()
+        .map(|t| *t == b'\n')
+        .unwrap_or(true)
+}
+
+fn readline_loop(mut body: impl FnMut(String) -> eyre::Result<()>) -> eyre::Result<()> {
+    let mut rl = rustyline::Editor::<LineContinuationValidator, DefaultHistory>::new()?;
+    rl.set_helper(Some(LineContinuationValidator));
+    rl.bind_sequence(force_newline_event_seq(), Cmd::Newline);
+
+    loop {
+        match rl.readline(">> ") {
+            Ok(raw_line) => {
+                if let Err(err) = body(raw_line) {
+                    log::error!("{err}");
+                    break;
+                }
+            }
+            Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => {
+                break;
+            }
+            Err(err) => {
+                log::error!("{err}");
+                break;
+            }
+        }
+    }
+
+    Ok(())
+}
+
+#[cfg(not(windows))]
+fn force_newline_event_seq() -> KeyEvent {
+    KeyEvent(KeyCode::Enter, Modifiers::ALT)
+}
+
+// On Windows, `SHIFT+ENTER` is the key sequence for forcing a newline. This is
+// because `ALT+ENTER` typically maximizes the window.
+#[cfg(windows)]
+fn force_newline_event_seq() -> KeyEvent {
+    KeyEvent(KeyCode::Enter, Modifiers::SHIFT)
+}
+
+#[derive(Completer, Helper, Highlighter, Hinter, Debug, Clone, Copy)]
+struct LineContinuationValidator;
+
+impl Validator for LineContinuationValidator {
+    fn validate(&self, ctx: &mut ValidationContext) -> rustyline::Result<ValidationResult> {
+        if ctx.input().ends_with('\\') {
+            Ok(ValidationResult::Incomplete)
+        } else {
+            Ok(ValidationResult::Valid(None))
+        }
+    }
+}
diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs
index 9cd8cb84..1c55824f 100644
--- a/binaries/llm-cli/src/main.rs
+++ b/binaries/llm-cli/src/main.rs
@@ -1,24 +1,19 @@
 use std::{
     convert::Infallible,
     fs::File,
-    io::{BufReader, BufWriter, Write},
+    io::{BufReader, BufWriter},
 };
 
 use clap::Parser;
 use cli_args::Args;
-use color_eyre::eyre::{bail, Context, ContextCompat, Result};
-use llm::{InferenceError, InferenceFeedback, InferenceResponse, InferenceSession};
-use rustyline::{
-    error::ReadlineError,
-    history::DefaultHistory,
-    validate::{ValidationContext, ValidationResult, Validator},
-    Cmd, Completer, Helper, Highlighter, Hinter, KeyCode, KeyEvent, Modifiers,
-};
+use color_eyre::eyre::{self, Context, ContextCompat};
 
 mod cli_args;
+mod interactive;
 mod snapshot;
+mod util;
 
-fn main() -> Result<()> {
+fn main() -> eyre::Result<()> {
     env_logger::builder()
         .filter_level(log::LevelFilter::Info)
         .parse_default_env()
@@ -31,13 +26,13 @@ fn main() -> Result<()> {
         Args::Perplexity(args) => perplexity(&args),
         Args::Info(args) => info(&args),
         Args::PromptTokens(args) => prompt_tokens(&args),
-        Args::Repl(args) => repl(&args),
-        Args::Chat(args) => chat(&args),
+        Args::Repl(args) => interactive::repl(&args),
+        Args::Chat(args) => interactive::chat(&args),
         Args::Quantize(args) => quantize(&args),
     }
 }
 
-fn infer(args: &cli_args::Infer) -> Result<()> {
+fn infer(args: &cli_args::Infer) -> eyre::Result<()> {
     let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?;
     let inference_session_config = args.generate.inference_session_config();
     let model = args.model_load.load(args.generate.use_gpu)?;
@@ -62,18 +57,13 @@ fn infer(args: &cli_args::Infer) -> Result<()> {
         },
         // OutputRequest
         &mut Default::default(),
-        |r| match &r {
-            InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
-                if matches!(&r, InferenceResponse::PromptToken(_)) && args.hide_prompt {
-                    return Ok(InferenceFeedback::Continue);
-                }
-
-                print!("{t}");
-                std::io::stdout().flush().unwrap();
-
-                Ok(InferenceFeedback::Continue)
+        |r| {
+            match r {
+                llm::InferenceResponse::PromptToken(t) if !args.hide_prompt => util::print_token(t),
+                llm::InferenceResponse::InferredToken(t) => util::print_token(t),
+                _ => {}
             }
-            _ => Ok(InferenceFeedback::Continue),
+            Ok(llm::InferenceFeedback::Continue)
         },
     );
     println!();
@@ -86,13 +76,13 @@ fn infer(args: &cli_args::Infer) -> Result<()> {
                 println!();
             }
         }
-        Err(InferenceError::ContextFull) => {
+        Err(llm::InferenceError::ContextFull) => {
            log::warn!("Context window full, stopping inference.")
         }
-        Err(InferenceError::TokenizationFailed(err)) => {
+        Err(llm::InferenceError::TokenizationFailed(err)) => {
            log::error!("A tokenization-related failure occurred: {}", err);
         }
-        Err(InferenceError::UserCallback(_)) | Err(InferenceError::EndOfText) => {
+        Err(llm::InferenceError::UserCallback(_)) | Err(llm::InferenceError::EndOfText) => {
            unreachable!("cannot fail")
         }
     }
@@ -105,7 +95,7 @@ fn infer(args: &cli_args::Infer) -> Result<()> {
     Ok(())
 }
 
-fn perplexity(args: &cli_args::Perplexity) -> Result<()> {
+fn perplexity(args: &cli_args::Perplexity) -> eyre::Result<()> {
     let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?;
     let inference_session_config = args.generate.inference_session_config();
     let model = args.model_load.load(args.generate.use_gpu)?;
@@ -125,10 +115,10 @@ fn perplexity(args: &cli_args::Perplexity) -> Result<()> {
     Ok(())
 }
 
-fn info(args: &cli_args::Info) -> Result<()> {
+fn info(args: &cli_args::Info) -> eyre::Result<()> {
     struct InfoVisitor<'a>(&'a cli_args::Info);
-    impl llm::ModelArchitectureVisitor<Result<()>> for InfoVisitor<'_> {
-        fn visit<M: llm::KnownModel + 'static>(&mut self) -> Result<()> {
+    impl llm::ModelArchitectureVisitor<eyre::Result<()>> for InfoVisitor<'_> {
+        fn visit<M: llm::KnownModel + 'static>(&mut self) -> eyre::Result<()> {
             let args = self.0;
 
             let model_path = &args.model_and_tokenizer.model_path;
@@ -178,7 +168,7 @@ fn info(args: &cli_args::Info) -> Result<()> {
         .visit(&mut InfoVisitor(args))
 }
 
-fn prompt_tokens(args: &cli_args::PromptTokens) -> Result<()> {
+fn prompt_tokens(args: &cli_args::PromptTokens) -> eyre::Result<()> {
     let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?;
     let model = args.model_load.load(false)?;
     let toks = match model.tokenizer().tokenize(&prompt, false) {
@@ -207,160 +197,12 @@ fn prompt_tokens(args: &cli_args::PromptTokens) -> Result<()> {
     Ok(())
 }
 
-#[cfg(not(windows))]
-fn force_newline_event_seq() -> KeyEvent {
-    KeyEvent(KeyCode::Enter, Modifiers::ALT)
-}
-
-// On Windows, `SHIFT+ENTER` is the key sequence for forcing a newline. This is
-// because `ALT+ENTER` typically maximizes the window.
-#[cfg(windows)]
-fn force_newline_event_seq() -> KeyEvent {
-    KeyEvent(KeyCode::Enter, Modifiers::SHIFT)
-}
-
-fn repl(args: &cli_args::Repl) -> Result<()> {
-    interactive(
-        &args.generate,
-        &args.model_load,
-        false,
-        None,
-        args.prompt_file.contents()?.as_deref(),
-    )
-}
-
-fn chat(args: &cli_args::Chat) -> Result<()> {
-    interactive(
-        &args.generate,
-        &args.model_load,
-        true,
-        Some(std::fs::read_to_string(&args.prelude_prompt_file)?.as_str()),
-        Some(&args.message_prompt()?),
-    )
-}
-
-fn interactive(
-    generate: &cli_args::Generate,
-    model_load: &cli_args::ModelLoad,
-    chat_mode: bool,
-    mut initial_prompt_template: Option<&str>,
-    message_prompt_template: Option<&str>,
-) -> Result<()> {
-    let inference_session_config = generate.inference_session_config();
-    let model = model_load.load(generate.use_gpu)?;
-
-    let recreate_session =
-        || snapshot::read_or_create_session(model.as_ref(), None, None, inference_session_config).0;
-    let mut session = recreate_session();
-
-    let parameters = generate.inference_parameters(model.eot_token_id());
-    let mut rng = generate.rng();
-
-    fn session_ends_with_newline(session: &InferenceSession) -> bool {
-        session
-            .decoded_tokens()
-            .last()
-            .map(|t| *t == b'\n')
-            .unwrap_or(false)
-    }
-
-    let mut infer = |session: &mut InferenceSession, mut prompt: String| {
-        // Add a newline to the beginning of the prompt if the last character in the session is not a newline
-        if !session_ends_with_newline(session) {
-            prompt.insert(0, '\n');
-        }
-
-        let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None);
-        if let Err(InferenceError::ContextFull) = session.feed_prompt(
-            model.as_ref(),
-            &parameters,
-            &prompt,
-            // OutputRequest
-            &mut Default::default(),
-            |_| Ok::<_, Infallible>(InferenceFeedback::Continue),
-        ) {
-            log::error!("Prompt exceeds context window length.")
-        };
-        sp.clear();
-
-        session.infer::<Infallible>(
-            model.as_ref(),
-            &mut rng,
-            &llm::InferenceRequest {
-                prompt: "".into(),
-                parameters: &parameters,
-                play_back_previous_tokens: false,
-                maximum_token_count: generate.num_predict,
-            },
-            &mut Default::default(),
-            |r| match r {
-                InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => {
-                    print!("{t}");
-                    std::io::stdout().flush().unwrap();
-
-                    Ok(InferenceFeedback::Continue)
-                }
-                _ => Ok(InferenceFeedback::Continue),
-            },
-        )
-    };
-
-    let mut rl = rustyline::Editor::<LineContinuationValidator, DefaultHistory>::new()?;
-    rl.set_helper(Some(LineContinuationValidator));
-    rl.bind_sequence(force_newline_event_seq(), Cmd::Newline);
-
-    loop {
-        let readline = rl.readline(">> ");
-        match readline {
-            Ok(raw_line) => {
-                let line = raw_line.replace("\\\n", "\n");
-
-                // Use the initial prompt template for the first inference,
-                // and then switch to the message prompt template afterwards
-                let mut prompt = initial_prompt_template
-                    .take()
-                    .or(message_prompt_template)
-                    .map(|pf| process_prompt(pf, &line))
-                    .unwrap_or(line);
-
-                // Add a newline to the end of the prompt if it doesn't end with one in chat mode
-                if chat_mode && !prompt.ends_with('\n') {
-                    prompt.push('\n');
-                }
-
-                if let Err(err) = infer(&mut session, prompt) {
-                    log::error!("{err}");
-                    break;
-                }
-
-                if !session_ends_with_newline(&session) {
-                    println!();
-                }
-
-                // Reload session in REPL mode
-                if message_prompt_template.is_none() {
-                    session = recreate_session();
-                }
-            }
-            Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => {
-                break;
-            }
-            Err(err) => {
-                log::error!("{err}");
-                break;
-            }
-        }
-    }
-
-    Ok(())
-}
-
-fn quantize(args: &cli_args::Quantize) -> Result<()> {
+fn quantize(args: &cli_args::Quantize) -> eyre::Result<()> {
     use llm::QuantizeProgress;
 
     struct QuantizeVisitor<'a>(&'a cli_args::Quantize);
-    impl llm::ModelArchitectureVisitor<Result<()>> for QuantizeVisitor<'_> {
-        fn visit<M: llm::KnownModel + 'static>(&mut self) -> Result<()> {
+    impl llm::ModelArchitectureVisitor<eyre::Result<()>> for QuantizeVisitor<'_> {
+        fn visit<M: llm::KnownModel + 'static>(&mut self) -> eyre::Result<()> {
             let args = self.0;
 
             let mut source: BufReader<File> = BufReader::new(std::fs::File::open(&args.source)?);
@@ -418,33 +260,11 @@ fn quantize(args: &cli_args::Quantize) -> Result<()> {
 fn load_prompt_file_with_prompt(
     prompt_file: &cli_args::PromptFile,
     prompt: Option<&str>,
-) -> Result<String> {
-    Ok(if let Some(prompt_file) = prompt_file.contents()? {
-        if let Some(prompt) = prompt {
-            process_prompt(&prompt_file, prompt)
-        } else {
-            prompt_file
-        }
-    } else if let Some(prompt) = prompt {
-        prompt.to_owned()
-    } else {
-        bail!("No prompt or prompt file was provided. See --help");
+) -> eyre::Result<String> {
+    Ok(match (prompt_file.contents()?, prompt) {
+        (Some(prompt_file), None) => prompt_file,
+        (None, Some(prompt)) => prompt.to_owned(),
+        (Some(prompt_file), Some(prompt)) => util::process_prompt(&prompt_file, prompt),
+        (None, None) => eyre::bail!("No prompt or prompt file was provided. See --help"),
     })
 }
-
-#[derive(Completer, Helper, Highlighter, Hinter, Debug, Clone, Copy)]
-struct LineContinuationValidator;
-
-impl Validator for LineContinuationValidator {
-    fn validate(&self, ctx: &mut ValidationContext) -> rustyline::Result<ValidationResult> {
-        if ctx.input().ends_with('\\') {
-            Ok(ValidationResult::Incomplete)
-        } else {
-            Ok(ValidationResult::Valid(None))
-        }
-    }
-}
-
-fn process_prompt(raw_prompt: &str, prompt: &str) -> String {
-    raw_prompt.replace("{{PROMPT}}", prompt)
-}
diff --git a/binaries/llm-cli/src/util.rs b/binaries/llm-cli/src/util.rs
new file mode 100644
index 00000000..a925e005
--- /dev/null
+++ b/binaries/llm-cli/src/util.rs
@@ -0,0 +1,10 @@
+use std::io::Write;
+
+pub fn process_prompt(raw_prompt: &str, prompt: &str) -> String {
+    raw_prompt.replace("{{PROMPT}}", prompt)
+}
+
+pub fn print_token(t: String) {
+    print!("{t}");
+    std::io::stdout().flush().unwrap();
+}
diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs
index 2c4fcf6e..caff5e67 100644
--- a/crates/llm-base/src/inference_session.rs
+++ b/crates/llm-base/src/inference_session.rs
@@ -280,7 +280,7 @@ impl InferenceSession {
     }
 
     /// Feed a prompt to the model for this session.
-    pub fn feed_prompt<'a, E: std::error::Error + 'static, P: Into<Prompt<'a>>>(
+    pub fn feed_prompt<'a, E: std::error::Error + Send + Sync + 'static, P: Into<Prompt<'a>>>(
         &mut self,
         model: &dyn Model,
         params: &InferenceParameters,
@@ -407,7 +407,7 @@ impl InferenceSession {
     /// generated (specified by [InferenceRequest::maximum_token_count]).
     ///
     /// This is a wrapper around [Self::feed_prompt] and [Self::infer_next_token].
-    pub fn infer<E: std::error::Error + 'static>(
+    pub fn infer<E: std::error::Error + Send + Sync + 'static>(
         &mut self,
         model: &dyn Model,
         rng: &mut impl rand::Rng,
@@ -438,14 +438,16 @@ impl InferenceSession {
         let parameters = request.parameters;
 
         // Feed the initial prompt through the transformer, to update its
-        // context window with new data.
-        self.feed_prompt(
-            model,
-            parameters,
-            request.prompt,
-            output_request,
-            feed_prompt_callback(&mut callback),
-        )?;
+        // context window with new data, if necessary.
+        if !request.prompt.is_empty() {
+            self.feed_prompt(
+                model,
+                parameters,
+                request.prompt,
+                output_request,
+                feed_prompt_callback(&mut callback),
+            )?;
+        }
         stats.feed_prompt_duration = start_at.elapsed().unwrap();
         stats.prompt_tokens = self.n_past;
 
@@ -661,7 +663,7 @@ pub enum InferenceError {
     EndOfText,
     #[error("the user-specified callback returned an error")]
     /// The user-specified callback returned an error.
-    UserCallback(Box<dyn std::error::Error>),
+    UserCallback(Box<dyn std::error::Error + Send + Sync>),
 }
 
 #[derive(Error, Debug)]
@@ -883,7 +885,7 @@ pub enum InferenceFeedback {
 
 /// Adapt an [InferenceResponse] callback so that it can be used in a call to
 /// [InferenceSession::feed_prompt].
-pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>(
+pub fn feed_prompt_callback<'a, E: std::error::Error + Send + Sync + 'static>(
     mut callback: impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a,
 ) -> impl FnMut(&[u8]) -> Result<InferenceFeedback, E> + 'a {
     let mut buffer = TokenUtf8Buffer::new();
@@ -892,3 +894,40 @@
         None => Ok(InferenceFeedback::Continue),
     }
 }
+
+/// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated.
+/// This callback is used in [InferenceSession::infer] in chat mode.
+pub fn conversation_inference_callback<'a, E: std::error::Error + Send + Sync + 'static>(
+    stop_sequence: &'a str,
+    mut callback: impl FnMut(String) + 'a,
+) -> impl FnMut(InferenceResponse) -> Result<InferenceFeedback, E> + 'a {
+    let mut stop_sequence_buf = String::new();
+    move |resp| match resp {
+        InferenceResponse::InferredToken(token) => {
+            // We've generated a token, so we need to check if it's contained in the stop sequence.
+            let mut buf = stop_sequence_buf.clone();
+            buf.push_str(&token);
+
+            if buf.starts_with(stop_sequence) {
+                // We've generated the stop sequence, so we're done.
+                // Note that this will contain the extra tokens that were generated after the stop sequence,
+                // which may affect generation. This is non-ideal, but it's the best we can do without
+                // modifying the model.
+                stop_sequence_buf.clear();
+                return Ok(InferenceFeedback::Halt);
+            } else if stop_sequence.starts_with(&buf) {
+                // We've generated a prefix of the stop sequence, so we need to keep buffering.
+                stop_sequence_buf = buf;
+                return Ok(InferenceFeedback::Continue);
+            }
+
+            // We've generated a token that isn't part of the stop sequence, so we can
+            // pass it to the callback.
+            stop_sequence_buf.clear();
+            callback(buf);
+            Ok(InferenceFeedback::Continue)
+        }
+        InferenceResponse::EotToken => Ok(InferenceFeedback::Halt),
+        _ => Ok(InferenceFeedback::Continue),
+    }
+}
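[Note: not part of the patch] A rough sketch of how a caller wires up the new `conversation_inference_callback`, assuming `session`, `model`, `rng`, and `parameters` are already set up as in `interactive::chat` and the surrounding function returns `eyre::Result`; the prompt text and the "User: " stop sequence are illustrative:

    session.infer::<std::convert::Infallible>(
        model,
        &mut rng,
        &llm::InferenceRequest {
            prompt: "User: What is the capital of France?\n".into(),
            parameters: &parameters,
            play_back_previous_tokens: false,
            maximum_token_count: None,
        },
        &mut Default::default(),
        // Print each inferred token as it arrives, halting as soon as the model
        // starts generating the next "User: " turn (the stop sequence).
        llm::conversation_inference_callback("User: ", util::print_token),
    )?;

The same pattern appears in `interactive::chat` above and in the updated `vicuna-chat` example further down in this diff.
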
diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs
index 1ec18d1c..3b2e80e6 100644
--- a/crates/llm-base/src/lib.rs
+++ b/crates/llm-base/src/lib.rs
@@ -23,9 +23,10 @@ pub use ggml;
 pub use ggml::Type as ElementType;
 
 pub use inference_session::{
-    feed_prompt_callback, GraphOutputs, InferenceError, InferenceFeedback, InferenceRequest,
-    InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot,
-    InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, RewindError, SnapshotError,
+    conversation_inference_callback, feed_prompt_callback, GraphOutputs, InferenceError,
+    InferenceFeedback, InferenceRequest, InferenceResponse, InferenceSession,
+    InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats,
+    ModelKVMemoryType, RewindError, SnapshotError,
 };
 pub use loader::{
     load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic,
diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs
index 25c71ba0..52cbee00 100644
--- a/crates/llm-base/src/tokenizer/mod.rs
+++ b/crates/llm-base/src/tokenizer/mod.rs
@@ -230,6 +230,14 @@ impl Prompt<'_> {
             }
         })
     }
+
+    /// Returns whether this prompt is empty.
+    pub fn is_empty(&self) -> bool {
+        match self {
+            Self::Text(text) => text.is_empty(),
+            Self::Tokens(tokens) => tokens.is_empty(),
+        }
+    }
 }
 impl<'a> Default for Prompt<'a> {
     fn default() -> Self {
diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs
index 7cdeb1d1..75702bb1 100644
--- a/crates/llm/examples/vicuna-chat.rs
+++ b/crates/llm/examples/vicuna-chat.rs
@@ -1,4 +1,5 @@
 use clap::Parser;
+use llm_base::conversation_inference_callback;
 use rustyline::error::ReadlineError;
 use std::{convert::Infallible, io::Write, path::PathBuf};
 
@@ -62,7 +63,11 @@ fn main() {
         &mut Default::default(),
         llm::feed_prompt_callback(|resp| match resp {
             llm::InferenceResponse::PromptToken(t)
-            | llm::InferenceResponse::InferredToken(t) => print_token(t),
+            | llm::InferenceResponse::InferredToken(t) => {
+                print_token(t);
+
+                Ok::<llm::InferenceFeedback, Infallible>(llm::InferenceFeedback::Continue)
+            }
             _ => Ok(llm::InferenceFeedback::Continue),
         }),
     )
@@ -72,7 +77,6 @@ fn main() {
 
     let mut rng = rand::thread_rng();
     let mut res = llm::InferenceStats::default();
-    let mut buf = String::new();
 
     loop {
         println!();
@@ -81,7 +85,7 @@ fn main() {
         match readline {
             Ok(line) => {
                 let stats = session
-                    .infer(
+                    .infer::<Infallible>(
                         model.as_ref(),
                         &mut rng,
                         &llm::InferenceRequest {
@@ -93,7 +97,7 @@ fn main() {
                             maximum_token_count: None,
                         },
                         &mut Default::default(),
-                        inference_callback(String::from(user_name), &mut buf),
+                        conversation_inference_callback(&format!("{character_name}:"), print_token),
                     )
                     .unwrap_or_else(|e| panic!("{e}"));
 
@@ -116,36 +120,7 @@ fn main() {
     println!("\n\nInference stats:\n{res}");
 }
 
-fn inference_callback(
-    stop_sequence: String,
-    buf: &mut String,
-) -> impl FnMut(llm::InferenceResponse) -> Result<llm::InferenceFeedback, Infallible> + '_ {
-    move |resp| match resp {
-        llm::InferenceResponse::InferredToken(t) => {
-            let mut reverse_buf = buf.clone();
-            reverse_buf.push_str(t.as_str());
-            if stop_sequence.as_str().eq(reverse_buf.as_str()) {
-                buf.clear();
-                return Ok(llm::InferenceFeedback::Halt);
-            } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) {
-                buf.push_str(t.as_str());
-                return Ok(llm::InferenceFeedback::Continue);
-            }
-
-            if buf.is_empty() {
-                print_token(t)
-            } else {
-                print_token(reverse_buf)
-            }
-        }
-        llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt),
-        _ => Ok(llm::InferenceFeedback::Continue),
-    }
-}
-
-fn print_token(t: String) -> Result<llm::InferenceFeedback, Infallible> {
+fn print_token(t: String) {
     print!("{t}");
     std::io::stdout().flush().unwrap();
-
-    Ok(llm::InferenceFeedback::Continue)
 }
diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs
index 35692951..96501221 100644
--- a/crates/llm/src/lib.rs
+++ b/crates/llm/src/lib.rs
@@ -77,14 +77,14 @@ use std::{
 
 // Try not to expose too many GGML details here.
 // This is the "user-facing" API, and GGML may not always be our backend.
 pub use llm_base::{
-    feed_prompt_callback, ggml::format as ggml_format, load, load_progress_callback_stdout,
-    quantize, samplers, ElementType, FileType, FileTypeFormat, FormatMagic, Hyperparameters,
-    InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse,
-    InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef,
-    InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model,
-    ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress,
-    RewindError, Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError,
-    Tokenizer, TokenizerSource,
+    conversation_inference_callback, feed_prompt_callback, ggml::format as ggml_format, load,
+    load_progress_callback_stdout, quantize, samplers, ElementType, FileType, FileTypeFormat,
+    FormatMagic, Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters,
+    InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig,
+    InferenceSnapshot, InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel,
+    LoadError, LoadProgress, Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest,
+    Prompt, QuantizeError, QuantizeProgress, RewindError, Sampler, SnapshotError, TokenBias,
+    TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, TokenizerSource,
 };
 
 use serde::Serialize;
diff --git a/utils/prompts/pygmalion-message.txt b/utils/prompts/pygmalion-message.txt
new file mode 100644
index 00000000..6089c80f
--- /dev/null
+++ b/utils/prompts/pygmalion-message.txt
@@ -0,0 +1 @@
+You: 
\ No newline at end of file
diff --git a/utils/prompts/pygmalion.txt b/utils/prompts/pygmalion-prelude.txt
similarity index 71%
rename from utils/prompts/pygmalion.txt
rename to utils/prompts/pygmalion-prelude.txt
index 8b3cbadb..d6629953 100644
--- a/utils/prompts/pygmalion.txt
+++ b/utils/prompts/pygmalion-prelude.txt
@@ -1,4 +1,3 @@
 Assistant's Persona: Assistant is a highly intelligent language model trained to comply with user requests.
 
-Assistant: How may I help you?
-You: {{PROMPT}}
\ No newline at end of file
+Assistant: How may I help you?
\ No newline at end of file
diff --git a/utils/prompts/vicuna-message.txt b/utils/prompts/vicuna-message.txt
new file mode 100644
index 00000000..6b32a757
--- /dev/null
+++ b/utils/prompts/vicuna-message.txt
@@ -0,0 +1 @@
+User: 
\ No newline at end of file
diff --git a/utils/prompts/vicuna.txt b/utils/prompts/vicuna-prelude.txt
similarity index 76%
rename from utils/prompts/vicuna.txt
rename to utils/prompts/vicuna-prelude.txt
index 1da54727..9c1bf219 100644
--- a/utils/prompts/vicuna.txt
+++ b/utils/prompts/vicuna-prelude.txt
@@ -1,4 +1,3 @@
 A chat between a human ("User") and an AI assistant ("Assistant"). The assistant gives helpful, detailed, and polite answers to the human's questions.
 
-Assistant: How may I help you?
-You: {{PROMPT}}
\ No newline at end of file
+Assistant: How may I help you?
\ No newline at end of file
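[Note: not part of the patch] The renamed prompt files split the old single template into a one-time prelude and a per-message prefix, and that prefix doubles as the stop sequence handed to `conversation_inference_callback`. A rough sketch of the text a Vicuna-style session now sees; the user line is illustrative, and `interactive::chat` feeds these pieces to the model incrementally rather than building a single string:

    fn main() {
        // Prelude fed once at startup (contents of vicuna-prelude.txt).
        let prelude = concat!(
            "A chat between a human (\"User\") and an AI assistant (\"Assistant\"). ",
            "The assistant gives helpful, detailed, and polite answers to the human's questions.\n\n",
            "Assistant: How may I help you?"
        );
        // Per-message prefix (contents of vicuna-message.txt); also the stop sequence.
        let message_prefix = "User: ";
        let user_line = "What is the capital of France?";

        println!("{prelude}\n{message_prefix}{user_line}");
    }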