From 43b4f0bea2dbd8a24af88ee125f49cfe0c1270cd Mon Sep 17 00:00:00 2001 From: averypelle Date: Sun, 9 Jul 2023 15:43:49 -0400 Subject: [PATCH 01/12] allow chat to halt new token generation --- binaries/llm-cli/src/main.rs | 46 +++++++++++++++++++++++++++++------- 1 file changed, 37 insertions(+), 9 deletions(-) diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 9cd8cb84..f782fa51 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -256,6 +256,8 @@ fn interactive( let parameters = generate.inference_parameters(model.eot_token_id()); let mut rng = generate.rng(); + let mut buf = String::new(); + fn session_ends_with_newline(session: &InferenceSession) -> bool { session .decoded_tokens() @@ -264,6 +266,33 @@ fn interactive( .unwrap_or(false) } + fn inference_callback( + stop_sequence: String, + buf: &mut String, + ) -> impl FnMut(InferenceResponse) -> Result + '_ { + move |resp| match resp { + InferenceResponse::InferredToken(t) => { + let mut reverse_buf = buf.clone(); + reverse_buf.push_str(t.as_str()); + if stop_sequence.as_str().eq(reverse_buf.as_str()) { + buf.clear(); + return Ok(InferenceFeedback::Halt); + } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) { + buf.push_str(t.as_str()); + return Ok(InferenceFeedback::Continue); + } + + if buf.is_empty() { + print_token(t) + } else { + print_token(reverse_buf) + } + } + InferenceResponse::EotToken => Ok(InferenceFeedback::Halt), + _ => Ok(InferenceFeedback::Continue), + } + } + let mut infer = |session: &mut InferenceSession, mut prompt: String| { // Add a newline to the beginning of the prompt if the last character in the session is not a newline if !session_ends_with_newline(session) { @@ -293,15 +322,7 @@ fn interactive( maximum_token_count: generate.num_predict, }, &mut Default::default(), - |r| match r { - InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => { - print!("{t}"); - std::io::stdout().flush().unwrap(); - - Ok(InferenceFeedback::Continue) - } - _ => Ok(InferenceFeedback::Continue), - }, + inference_callback(String::from("User:"), &mut buf), ) }; @@ -448,3 +469,10 @@ impl Validator for LineContinuationValidator { fn process_prompt(raw_prompt: &str, prompt: &str) -> String { raw_prompt.replace("{{PROMPT}}", prompt) } + +fn print_token(t: String) -> Result { + print!("{t}"); + std::io::stdout().flush().unwrap(); + + Ok(llm::InferenceFeedback::Continue) +} From f87b4fa05476984b642db488b3aeb64f5936e02a Mon Sep 17 00:00:00 2001 From: averypelle Date: Sun, 9 Jul 2023 16:23:28 -0400 Subject: [PATCH 02/12] pull function out --- binaries/llm-cli/src/main.rs | 65 ++++++++++++++++++++---------------- 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index f782fa51..e3c9c1f8 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -256,6 +256,10 @@ fn interactive( let parameters = generate.inference_parameters(model.eot_token_id()); let mut rng = generate.rng(); + let stop_sequence = message_prompt_template + .map(|s| s.replace("{{PROMPT}}", "").trim().to_owned()) + .unwrap_or_default(); + let mut buf = String::new(); fn session_ends_with_newline(session: &InferenceSession) -> bool { @@ -266,33 +270,6 @@ fn interactive( .unwrap_or(false) } - fn inference_callback( - stop_sequence: String, - buf: &mut String, - ) -> impl FnMut(InferenceResponse) -> Result + '_ { - move |resp| match resp { - InferenceResponse::InferredToken(t) => { - let mut reverse_buf = buf.clone(); - reverse_buf.push_str(t.as_str()); - if stop_sequence.as_str().eq(reverse_buf.as_str()) { - buf.clear(); - return Ok(InferenceFeedback::Halt); - } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) { - buf.push_str(t.as_str()); - return Ok(InferenceFeedback::Continue); - } - - if buf.is_empty() { - print_token(t) - } else { - print_token(reverse_buf) - } - } - InferenceResponse::EotToken => Ok(InferenceFeedback::Halt), - _ => Ok(InferenceFeedback::Continue), - } - } - let mut infer = |session: &mut InferenceSession, mut prompt: String| { // Add a newline to the beginning of the prompt if the last character in the session is not a newline if !session_ends_with_newline(session) { @@ -322,7 +299,7 @@ fn interactive( maximum_token_count: generate.num_predict, }, &mut Default::default(), - inference_callback(String::from("User:"), &mut buf), + inference_callback(stop_sequence.clone(), chat_mode, &mut buf), ) }; @@ -470,6 +447,38 @@ fn process_prompt(raw_prompt: &str, prompt: &str) -> String { raw_prompt.replace("{{PROMPT}}", prompt) } +fn inference_callback( + stop_sequence: String, + chat_mode: bool, + buf: &mut String, +) -> impl FnMut(InferenceResponse) -> Result + '_ { + move |resp| match resp { + InferenceResponse::InferredToken(t) => { + if chat_mode { + let mut reverse_buf = buf.clone(); + reverse_buf.push_str(t.as_str()); + if stop_sequence.as_str().eq(reverse_buf.as_str()) { + buf.clear(); + return Ok(InferenceFeedback::Halt); + } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) { + buf.push_str(t.as_str()); + return Ok(InferenceFeedback::Continue); + } + + if buf.is_empty() { + print_token(t) + } else { + print_token(reverse_buf) + } + } else { + print_token(t) + } + } + InferenceResponse::EotToken => Ok(InferenceFeedback::Halt), + _ => Ok(InferenceFeedback::Continue), + } +} + fn print_token(t: String) -> Result { print!("{t}"); std::io::stdout().flush().unwrap(); From 38d8632df7e2e13a5a9a4b3708dd7616b81af3ff Mon Sep 17 00:00:00 2001 From: averypelle Date: Mon, 10 Jul 2023 12:13:19 -0400 Subject: [PATCH 03/12] fix: move inference callback to llm-base --- binaries/llm-cli/src/main.rs | 95 +++++++++++------------- crates/llm-base/src/inference_session.rs | 33 ++++++++ crates/llm-base/src/lib.rs | 7 +- crates/llm/examples/vicuna-chat.rs | 46 ++++-------- crates/llm/src/lib.rs | 16 ++-- 5 files changed, 101 insertions(+), 96 deletions(-) diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index e3c9c1f8..14a19de2 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -7,7 +7,10 @@ use std::{ use clap::Parser; use cli_args::Args; use color_eyre::eyre::{bail, Context, ContextCompat, Result}; -use llm::{InferenceError, InferenceFeedback, InferenceResponse, InferenceSession}; +use llm::{ + conversation_inference_callback, InferenceError, InferenceFeedback, InferenceResponse, + InferenceSession, +}; use rustyline::{ error::ReadlineError, history::DefaultHistory, @@ -262,6 +265,11 @@ fn interactive( let mut buf = String::new(); + fn print_token(t: String) { + print!("{t}"); + std::io::stdout().flush().unwrap(); + } + fn session_ends_with_newline(session: &InferenceSession) -> bool { session .decoded_tokens() @@ -289,18 +297,40 @@ fn interactive( }; sp.clear(); - session.infer::( - model.as_ref(), - &mut rng, - &llm::InferenceRequest { - prompt: "".into(), - parameters: ¶meters, - play_back_previous_tokens: false, - maximum_token_count: generate.num_predict, - }, - &mut Default::default(), - inference_callback(stop_sequence.clone(), chat_mode, &mut buf), - ) + if chat_mode { + session.infer::( + model.as_ref(), + &mut rng, + &llm::InferenceRequest { + prompt: "".into(), + parameters: ¶meters, + play_back_previous_tokens: false, + maximum_token_count: generate.num_predict, + }, + &mut Default::default(), + conversation_inference_callback(stop_sequence.clone(), &mut buf, print_token), + ) + } else { + session.infer::( + model.as_ref(), + &mut rng, + &llm::InferenceRequest { + prompt: "".into(), + parameters: ¶meters, + play_back_previous_tokens: false, + maximum_token_count: generate.num_predict, + }, + &mut Default::default(), + |r| match r { + InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => { + print_token(t); + + Ok(InferenceFeedback::Continue) + } + _ => Ok(InferenceFeedback::Continue), + }, + ) + } }; let mut rl = rustyline::Editor::::new()?; @@ -446,42 +476,3 @@ impl Validator for LineContinuationValidator { fn process_prompt(raw_prompt: &str, prompt: &str) -> String { raw_prompt.replace("{{PROMPT}}", prompt) } - -fn inference_callback( - stop_sequence: String, - chat_mode: bool, - buf: &mut String, -) -> impl FnMut(InferenceResponse) -> Result + '_ { - move |resp| match resp { - InferenceResponse::InferredToken(t) => { - if chat_mode { - let mut reverse_buf = buf.clone(); - reverse_buf.push_str(t.as_str()); - if stop_sequence.as_str().eq(reverse_buf.as_str()) { - buf.clear(); - return Ok(InferenceFeedback::Halt); - } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) { - buf.push_str(t.as_str()); - return Ok(InferenceFeedback::Continue); - } - - if buf.is_empty() { - print_token(t) - } else { - print_token(reverse_buf) - } - } else { - print_token(t) - } - } - InferenceResponse::EotToken => Ok(InferenceFeedback::Halt), - _ => Ok(InferenceFeedback::Continue), - } -} - -fn print_token(t: String) -> Result { - print!("{t}"); - std::io::stdout().flush().unwrap(); - - Ok(llm::InferenceFeedback::Continue) -} diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index 37861174..526b370a 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -853,3 +853,36 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>( None => Ok(InferenceFeedback::Continue), } } + +/// Callback to be passed to [InferenceSession::infer] that will print the +/// token to stdout and will halt execution when the stop sequence is encountered. +/// Only to be used for chat mode. +pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>( + stop_sequence: String, + buf: &'a mut String, + mut print_token: impl FnMut(String) + 'a, +) -> impl FnMut(InferenceResponse) -> Result + 'a { + move |resp| match resp { + InferenceResponse::InferredToken(t) => { + let mut reverse_buf = buf.clone(); + reverse_buf.push_str(t.as_str()); + if stop_sequence.as_str().eq(reverse_buf.as_str()) { + buf.clear(); + return Ok(InferenceFeedback::Halt); + } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) { + buf.push_str(t.as_str()); + return Ok(InferenceFeedback::Continue); + } + + if buf.is_empty() { + print_token(t); + Ok(InferenceFeedback::Continue) + } else { + print_token(reverse_buf); + Ok(InferenceFeedback::Continue) + } + } + InferenceResponse::EotToken => Ok(InferenceFeedback::Halt), + _ => Ok(InferenceFeedback::Continue), + } +} diff --git a/crates/llm-base/src/lib.rs b/crates/llm-base/src/lib.rs index d40a9077..c1ec11f2 100644 --- a/crates/llm-base/src/lib.rs +++ b/crates/llm-base/src/lib.rs @@ -23,9 +23,10 @@ pub use ggml; pub use ggml::Type as ElementType; pub use inference_session::{ - feed_prompt_callback, GraphOutputs, InferenceError, InferenceFeedback, InferenceRequest, - InferenceResponse, InferenceSession, InferenceSessionConfig, InferenceSnapshot, - InferenceSnapshotRef, InferenceStats, ModelKVMemoryType, SnapshotError, + conversation_inference_callback, feed_prompt_callback, GraphOutputs, InferenceError, + InferenceFeedback, InferenceRequest, InferenceResponse, InferenceSession, + InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, InferenceStats, + ModelKVMemoryType, SnapshotError, }; pub use loader::{ load, load_progress_callback_stdout, ContainerType, FileType, FileTypeFormat, FormatMagic, diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs index 7cdeb1d1..a67810cc 100644 --- a/crates/llm/examples/vicuna-chat.rs +++ b/crates/llm/examples/vicuna-chat.rs @@ -1,4 +1,5 @@ use clap::Parser; +use llm_base::conversation_inference_callback; use rustyline::error::ReadlineError; use std::{convert::Infallible, io::Write, path::PathBuf}; @@ -62,7 +63,11 @@ fn main() { &mut Default::default(), llm::feed_prompt_callback(|resp| match resp { llm::InferenceResponse::PromptToken(t) - | llm::InferenceResponse::InferredToken(t) => print_token(t), + | llm::InferenceResponse::InferredToken(t) => { + print_token(t); + + Ok::(llm::InferenceFeedback::Continue) + } _ => Ok(llm::InferenceFeedback::Continue), }), ) @@ -81,7 +86,7 @@ fn main() { match readline { Ok(line) => { let stats = session - .infer( + .infer::( model.as_ref(), &mut rng, &llm::InferenceRequest { @@ -93,7 +98,11 @@ fn main() { maximum_token_count: None, }, &mut Default::default(), - inference_callback(String::from(user_name), &mut buf), + conversation_inference_callback( + String::from(user_name), + &mut buf, + print_token, + ), ) .unwrap_or_else(|e| panic!("{e}")); @@ -116,36 +125,7 @@ fn main() { println!("\n\nInference stats:\n{res}"); } -fn inference_callback( - stop_sequence: String, - buf: &mut String, -) -> impl FnMut(llm::InferenceResponse) -> Result + '_ { - move |resp| match resp { - llm::InferenceResponse::InferredToken(t) => { - let mut reverse_buf = buf.clone(); - reverse_buf.push_str(t.as_str()); - if stop_sequence.as_str().eq(reverse_buf.as_str()) { - buf.clear(); - return Ok(llm::InferenceFeedback::Halt); - } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) { - buf.push_str(t.as_str()); - return Ok(llm::InferenceFeedback::Continue); - } - - if buf.is_empty() { - print_token(t) - } else { - print_token(reverse_buf) - } - } - llm::InferenceResponse::EotToken => Ok(llm::InferenceFeedback::Halt), - _ => Ok(llm::InferenceFeedback::Continue), - } -} - -fn print_token(t: String) -> Result { +fn print_token(t: String) { print!("{t}"); std::io::stdout().flush().unwrap(); - - Ok(llm::InferenceFeedback::Continue) } diff --git a/crates/llm/src/lib.rs b/crates/llm/src/lib.rs index 2be90739..b971bd90 100644 --- a/crates/llm/src/lib.rs +++ b/crates/llm/src/lib.rs @@ -77,14 +77,14 @@ use std::{ // Try not to expose too many GGML details here. // This is the "user-facing" API, and GGML may not always be our backend. pub use llm_base::{ - feed_prompt_callback, ggml::format as ggml_format, load, load_progress_callback_stdout, - quantize, samplers, ElementType, FileType, FileTypeFormat, FormatMagic, Hyperparameters, - InferenceError, InferenceFeedback, InferenceParameters, InferenceRequest, InferenceResponse, - InferenceSession, InferenceSessionConfig, InferenceSnapshot, InferenceSnapshotRef, - InferenceStats, InvalidTokenBias, KnownModel, LoadError, LoadProgress, Loader, Model, - ModelKVMemoryType, ModelParameters, OutputRequest, Prompt, QuantizeError, QuantizeProgress, - Sampler, SnapshotError, TokenBias, TokenId, TokenUtf8Buffer, TokenizationError, Tokenizer, - TokenizerSource, + conversation_inference_callback, feed_prompt_callback, ggml::format as ggml_format, load, + load_progress_callback_stdout, quantize, samplers, ElementType, FileType, FileTypeFormat, + FormatMagic, Hyperparameters, InferenceError, InferenceFeedback, InferenceParameters, + InferenceRequest, InferenceResponse, InferenceSession, InferenceSessionConfig, + InferenceSnapshot, InferenceSnapshotRef, InferenceStats, InvalidTokenBias, KnownModel, + LoadError, LoadProgress, Loader, Model, ModelKVMemoryType, ModelParameters, OutputRequest, + Prompt, QuantizeError, QuantizeProgress, Sampler, SnapshotError, TokenBias, TokenId, + TokenUtf8Buffer, TokenizationError, Tokenizer, TokenizerSource, }; use serde::Serialize; From 41bf37a85f54837a6fb3aca4f7ba5f1de6ea77c7 Mon Sep 17 00:00:00 2001 From: averypelle Date: Wed, 12 Jul 2023 10:47:05 -0400 Subject: [PATCH 04/12] clarify function comment --- crates/llm-base/src/inference_session.rs | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index 526b370a..ee34a1c5 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -854,9 +854,8 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>( } } -/// Callback to be passed to [InferenceSession::infer] that will print the -/// token to stdout and will halt execution when the stop sequence is encountered. -/// Only to be used for chat mode. +/// An [InferenceResponse] callback that will halt inference when a stop_sequence is generated. +/// This callback is used in [InferenceSession::infer] in chat_mode. pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>( stop_sequence: String, buf: &'a mut String, From 39e45db21a19ba48705817e6e77229cd220f9497 Mon Sep 17 00:00:00 2001 From: averypelle Date: Wed, 12 Jul 2023 11:39:53 -0400 Subject: [PATCH 05/12] change args to message-prompt-prefix --- binaries/llm-cli/src/cli_args.rs | 60 ++++++++++++++++---------------- binaries/llm-cli/src/main.rs | 26 +++++++------- 2 files changed, 44 insertions(+), 42 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 022957e6..211f7451 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -179,51 +179,51 @@ pub struct Chat { #[arg(long, short = 'f')] pub prelude_prompt_file: PathBuf, - /// The per-message prompt to use. + /// The per-message prefix to be prepended to the user's message. /// - /// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the - /// user's message. + /// The `{{PROMPT}}` will automatically be appended to this prefix. #[arg(long, short = 'p')] - pub message_prompt: Option, + pub message_prompt_prefix: Option, - /// The file to read the per-message prompt from. + /// The file containing the per-message prefix to be prepended to the user's message. /// - /// Must contain a `{{PROMPT}}` placeholder, which will be replaced with the - /// user's message. + /// The `{{PROMPT}}` will automatically be appended to this prefix. #[arg(long, short = 'q')] - pub message_prompt_file: Option, + pub message_prompt_prefix_file: Option, #[command(flatten)] pub generate: Generate, } impl Chat { - pub fn message_prompt(&self) -> eyre::Result { - if self.message_prompt.is_some() && self.message_prompt_file.is_some() { - eyre::bail!("Cannot specify both --message-prompt and --message-prompt-file") + pub fn message_prompt_prefix(&self) -> eyre::Result { + if self.message_prompt_prefix.is_some() && self.message_prompt_prefix_file.is_some() { + eyre::bail!( + "Cannot specify both --message-prompt-prefix and --message-prompt-prefix-file" + ) } - if let Some(message_prompt_file) = &self.message_prompt_file { - read_prompt_file(message_prompt_file).and_then(|prompt| { - prompt - .contains("{{PROMPT}}") - .then_some(prompt) - .ok_or_else(|| { - eyre::eyre!( - "Message prompt file must contain a `{{{{PROMPT}}}}` placeholder, but it does not" - ) - }) + if let Some(message_prompt_prefix_file) = &self.message_prompt_prefix_file { + read_prompt_file(message_prompt_prefix_file).and_then(|prompt| { + if prompt.contains("{{PROMPT}}") { + eyre::bail!( + "Message prompt file must not contain a `{{{{PROMPT}}}}` placeholder. The `{{PROMPT}}` will be automatically appended to the prefix." + ) + } else { + Ok(prompt) + } }) - } else if let Some(message_prompt) = &self.message_prompt { - message_prompt - .contains("{{PROMPT}}") - .then(|| message_prompt.clone()) - .ok_or_else(|| { - eyre::eyre!( - "Message prompt must contain a `{{{{PROMPT}}}}` placeholder, but it does not" + } else if let Some(message_prompt_prefix) = &self.message_prompt_prefix { + if message_prompt_prefix.contains("{{PROMPT}}") { + eyre::bail!( + "Message prompt file must not contain a `{{{{PROMPT}}}}` placeholder. The `{{PROMPT}}` will be automatically appended to the prefix." ) - }) + } else { + Ok(message_prompt_prefix.clone()) + } } else { - eyre::bail!("Must specify either --message-prompt or --message-prompt-file") + eyre::bail!( + "Must specify either --message-prompt-prefix or --message-prompt-prefix-file" + ) } } } diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 14a19de2..73ccd636 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -238,7 +238,7 @@ fn chat(args: &cli_args::Chat) -> Result<()> { &args.model_load, true, Some(std::fs::read_to_string(&args.prelude_prompt_file)?.as_str()), - Some(&args.message_prompt()?), + Some(&args.message_prompt_prefix()?), ) } @@ -247,7 +247,7 @@ fn interactive( model_load: &cli_args::ModelLoad, chat_mode: bool, mut initial_prompt_template: Option<&str>, - message_prompt_template: Option<&str>, + message_prompt_prefix: Option<&str>, ) -> Result<()> { let inference_session_config = generate.inference_session_config(); let model = model_load.load(generate.use_gpu)?; @@ -259,10 +259,6 @@ fn interactive( let parameters = generate.inference_parameters(model.eot_token_id()); let mut rng = generate.rng(); - let stop_sequence = message_prompt_template - .map(|s| s.replace("{{PROMPT}}", "").trim().to_owned()) - .unwrap_or_default(); - let mut buf = String::new(); fn print_token(t: String) { @@ -298,6 +294,8 @@ fn interactive( sp.clear(); if chat_mode { + let stop_sequence = message_prompt_prefix.unwrap_or_default().to_owned(); + session.infer::( model.as_ref(), &mut rng, @@ -308,7 +306,7 @@ fn interactive( maximum_token_count: generate.num_predict, }, &mut Default::default(), - conversation_inference_callback(stop_sequence.clone(), &mut buf, print_token), + conversation_inference_callback(stop_sequence, &mut buf, print_token), ) } else { session.infer::( @@ -344,12 +342,16 @@ fn interactive( let line = raw_line.replace("\\\n", "\n"); // Use the initial prompt template for the first inference, - // and then switch to the message prompt template afterwards + // and then switch to the message prompt prefix afterwards. + // Only the initial prompt template needs to be formatted. let mut prompt = initial_prompt_template .take() - .or(message_prompt_template) - .map(|pf| process_prompt(pf, &line)) - .unwrap_or(line); + .map(|template| process_prompt(template, &line)) + .unwrap_or_else(|| { + message_prompt_prefix + .map(|prefix| format!("{} {}", prefix, line)) + .unwrap_or_else(|| line) + }); // Add a newline to the end of the prompt if it doesn't end with one in chat mode if chat_mode && !prompt.ends_with('\n') { @@ -366,7 +368,7 @@ fn interactive( } // Reload session in REPL mode - if message_prompt_template.is_none() { + if !chat_mode { session = recreate_session(); } } From b4efde7b9dc3bfa9f248812e4b6ca626cd809938 Mon Sep 17 00:00:00 2001 From: averypelle Date: Wed, 12 Jul 2023 12:07:15 -0400 Subject: [PATCH 06/12] update comment --- binaries/llm-cli/src/main.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 73ccd636..b27af930 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -343,7 +343,6 @@ fn interactive( // Use the initial prompt template for the first inference, // and then switch to the message prompt prefix afterwards. - // Only the initial prompt template needs to be formatted. let mut prompt = initial_prompt_template .take() .map(|template| process_prompt(template, &line)) From a0ad8b40326085f9d59794c59409a156cb63d422 Mon Sep 17 00:00:00 2001 From: Philpax Date: Wed, 12 Jul 2023 23:23:28 +0200 Subject: [PATCH 07/12] fix(cli): don't insert newline at start of chat --- binaries/llm-cli/src/main.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index b27af930..0e7520e7 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -271,7 +271,7 @@ fn interactive( .decoded_tokens() .last() .map(|t| *t == b'\n') - .unwrap_or(false) + .unwrap_or(true) } let mut infer = |session: &mut InferenceSession, mut prompt: String| { From 138263a42c96aa52db2a8973d6282e2859b8c754 Mon Sep 17 00:00:00 2001 From: Philpax Date: Wed, 12 Jul 2023 23:27:17 +0200 Subject: [PATCH 08/12] fix(llm): clarify conversation_inference_callback --- binaries/llm-cli/src/main.rs | 16 +++++----- crates/llm-base/src/inference_session.rs | 39 +++++++++++++----------- crates/llm/examples/vicuna-chat.rs | 7 +---- 3 files changed, 30 insertions(+), 32 deletions(-) diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index 0e7520e7..b1e13d90 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -259,13 +259,6 @@ fn interactive( let parameters = generate.inference_parameters(model.eot_token_id()); let mut rng = generate.rng(); - let mut buf = String::new(); - - fn print_token(t: String) { - print!("{t}"); - std::io::stdout().flush().unwrap(); - } - fn session_ends_with_newline(session: &InferenceSession) -> bool { session .decoded_tokens() @@ -293,6 +286,11 @@ fn interactive( }; sp.clear(); + fn print_token(t: String) { + print!("{t}"); + std::io::stdout().flush().unwrap(); + } + if chat_mode { let stop_sequence = message_prompt_prefix.unwrap_or_default().to_owned(); @@ -306,7 +304,7 @@ fn interactive( maximum_token_count: generate.num_predict, }, &mut Default::default(), - conversation_inference_callback(stop_sequence, &mut buf, print_token), + conversation_inference_callback(stop_sequence, print_token), ) } else { session.infer::( @@ -348,7 +346,7 @@ fn interactive( .map(|template| process_prompt(template, &line)) .unwrap_or_else(|| { message_prompt_prefix - .map(|prefix| format!("{} {}", prefix, line)) + .map(|prefix| format!("{}{}", prefix, line)) .unwrap_or_else(|| line) }); diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index 5db4b71e..e310ffa4 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -893,32 +893,37 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>( } } -/// An [InferenceResponse] callback that will halt inference when a stop_sequence is generated. +/// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated. /// This callback is used in [InferenceSession::infer] in chat_mode. pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>( stop_sequence: String, - buf: &'a mut String, - mut print_token: impl FnMut(String) + 'a, + mut callback: impl FnMut(String) + 'a, ) -> impl FnMut(InferenceResponse) -> Result + 'a { + let mut stop_sequence_buf = String::new(); move |resp| match resp { - InferenceResponse::InferredToken(t) => { - let mut reverse_buf = buf.clone(); - reverse_buf.push_str(t.as_str()); - if stop_sequence.as_str().eq(reverse_buf.as_str()) { - buf.clear(); + InferenceResponse::InferredToken(token) => { + // We've generated a token, so we need to check if it's contained in the stop sequence. + let mut buf = stop_sequence_buf.clone(); + buf.push_str(&token); + + if buf.starts_with(&stop_sequence) { + // We've generated the stop sequence, so we're done. + // Note that this will contain the extra tokens that were generated after the stop sequence, + // which may affect generation. This is non-ideal, but it's the best we can do without + // modifying the model. + stop_sequence_buf.clear(); return Ok(InferenceFeedback::Halt); - } else if stop_sequence.as_str().starts_with(reverse_buf.as_str()) { - buf.push_str(t.as_str()); + } else if stop_sequence.starts_with(&buf) { + // We've generated a prefix of the stop sequence, so we need to keep buffering. + stop_sequence_buf = buf; return Ok(InferenceFeedback::Continue); } - if buf.is_empty() { - print_token(t); - Ok(InferenceFeedback::Continue) - } else { - print_token(reverse_buf); - Ok(InferenceFeedback::Continue) - } + // We've generated a token that isn't part of the stop sequence, so we can + // pass it to the callback. + stop_sequence_buf.clear(); + callback(buf); + Ok(InferenceFeedback::Continue) } InferenceResponse::EotToken => Ok(InferenceFeedback::Halt), _ => Ok(InferenceFeedback::Continue), diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs index a67810cc..ecf73de4 100644 --- a/crates/llm/examples/vicuna-chat.rs +++ b/crates/llm/examples/vicuna-chat.rs @@ -77,7 +77,6 @@ fn main() { let mut rng = rand::thread_rng(); let mut res = llm::InferenceStats::default(); - let mut buf = String::new(); loop { println!(); @@ -98,11 +97,7 @@ fn main() { maximum_token_count: None, }, &mut Default::default(), - conversation_inference_callback( - String::from(user_name), - &mut buf, - print_token, - ), + conversation_inference_callback(format!("{character_name}:"), print_token), ) .unwrap_or_else(|e| panic!("{e}")); From 87025932b49ffdd4b96a2ad0ecaac2636ed8c09a Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 13 Jul 2023 01:29:12 +0200 Subject: [PATCH 09/12] refactor(cli): simplify message_prompt_prefix --- binaries/llm-cli/src/cli_args.rs | 48 ++++++++++++++++---------------- 1 file changed, 24 insertions(+), 24 deletions(-) diff --git a/binaries/llm-cli/src/cli_args.rs b/binaries/llm-cli/src/cli_args.rs index 211f7451..ca6cf9e4 100644 --- a/binaries/llm-cli/src/cli_args.rs +++ b/binaries/llm-cli/src/cli_args.rs @@ -196,34 +196,34 @@ pub struct Chat { } impl Chat { pub fn message_prompt_prefix(&self) -> eyre::Result { - if self.message_prompt_prefix.is_some() && self.message_prompt_prefix_file.is_some() { - eyre::bail!( + const MESSAGE_PROMPT_PREFIX_ERROR: &str = concat!( + "Message prompt prefix must not contain a `{{PROMPT}}` placeholder. ", + "The prompt will be automatically appended to the prefix." + ); + + match ( + &self.message_prompt_prefix, + &self.message_prompt_prefix_file, + ) { + (None, None) => eyre::bail!( + "Must specify either --message-prompt-prefix or --message-prompt-prefix-file" + ), + (Some(_), Some(_)) => eyre::bail!( "Cannot specify both --message-prompt-prefix and --message-prompt-prefix-file" - ) - } - - if let Some(message_prompt_prefix_file) = &self.message_prompt_prefix_file { - read_prompt_file(message_prompt_prefix_file).and_then(|prompt| { - if prompt.contains("{{PROMPT}}") { - eyre::bail!( - "Message prompt file must not contain a `{{{{PROMPT}}}}` placeholder. The `{{PROMPT}}` will be automatically appended to the prefix." - ) - } else { - Ok(prompt) + ), + (Some(message_prompt_prefix), None) => { + if message_prompt_prefix.contains("{{PROMPT}}") { + eyre::bail!("{MESSAGE_PROMPT_PREFIX_ERROR}"); } - }) - } else if let Some(message_prompt_prefix) = &self.message_prompt_prefix { - if message_prompt_prefix.contains("{{PROMPT}}") { - eyre::bail!( - "Message prompt file must not contain a `{{{{PROMPT}}}}` placeholder. The `{{PROMPT}}` will be automatically appended to the prefix." - ) - } else { Ok(message_prompt_prefix.clone()) } - } else { - eyre::bail!( - "Must specify either --message-prompt-prefix or --message-prompt-prefix-file" - ) + (None, Some(message_prompt_prefix_file)) => { + let prompt = read_prompt_file(message_prompt_prefix_file)?; + if prompt.contains("{{PROMPT}}") { + eyre::bail!("{MESSAGE_PROMPT_PREFIX_ERROR}"); + } + Ok(prompt) + } } } } From 710f3c2d8ec4318b15b2d46ae35bea17a7347025 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 13 Jul 2023 01:30:39 +0200 Subject: [PATCH 10/12] fix(llm): only feed prompt if not empty --- crates/llm-base/src/inference_session.rs | 4 +++- crates/llm-base/src/tokenizer/mod.rs | 8 ++++++++ 2 files changed, 11 insertions(+), 1 deletion(-) diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index e310ffa4..a71a57fd 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -438,7 +438,8 @@ impl InferenceSession { let parameters = request.parameters; // Feed the initial prompt through the transformer, to update its - // context window with new data. + // context window with new data, if necessary. + if !request.prompt.is_empty() { self.feed_prompt( model, parameters, @@ -446,6 +447,7 @@ impl InferenceSession { output_request, feed_prompt_callback(&mut callback), )?; + } stats.feed_prompt_duration = start_at.elapsed().unwrap(); stats.prompt_tokens = self.n_past; diff --git a/crates/llm-base/src/tokenizer/mod.rs b/crates/llm-base/src/tokenizer/mod.rs index 25c71ba0..52cbee00 100644 --- a/crates/llm-base/src/tokenizer/mod.rs +++ b/crates/llm-base/src/tokenizer/mod.rs @@ -230,6 +230,14 @@ impl Prompt<'_> { } }) } + + /// Returns whether this prompt is empty. + pub fn is_empty(&self) -> bool { + match self { + Self::Text(text) => text.is_empty(), + Self::Tokens(tokens) => tokens.is_empty(), + } + } } impl<'a> Default for Prompt<'a> { fn default() -> Self { From 74d2d67aefa34684920c993cc877b989a91f3da6 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 13 Jul 2023 01:31:52 +0200 Subject: [PATCH 11/12] fix(llm): require errors to be Send+Sync --- crates/llm-base/src/inference_session.rs | 28 ++++++++++++------------ 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/crates/llm-base/src/inference_session.rs b/crates/llm-base/src/inference_session.rs index a71a57fd..caff5e67 100644 --- a/crates/llm-base/src/inference_session.rs +++ b/crates/llm-base/src/inference_session.rs @@ -280,7 +280,7 @@ impl InferenceSession { } /// Feed a prompt to the model for this session. - pub fn feed_prompt<'a, E: std::error::Error + 'static, P: Into>>( + pub fn feed_prompt<'a, E: std::error::Error + Send + Sync + 'static, P: Into>>( &mut self, model: &dyn Model, params: &InferenceParameters, @@ -407,7 +407,7 @@ impl InferenceSession { /// generated (specified by [InferenceRequest::maximum_token_count]). /// /// This is a wrapper around [Self::feed_prompt] and [Self::infer_next_token]. - pub fn infer( + pub fn infer( &mut self, model: &dyn Model, rng: &mut impl rand::Rng, @@ -440,13 +440,13 @@ impl InferenceSession { // Feed the initial prompt through the transformer, to update its // context window with new data, if necessary. if !request.prompt.is_empty() { - self.feed_prompt( - model, - parameters, - request.prompt, - output_request, - feed_prompt_callback(&mut callback), - )?; + self.feed_prompt( + model, + parameters, + request.prompt, + output_request, + feed_prompt_callback(&mut callback), + )?; } stats.feed_prompt_duration = start_at.elapsed().unwrap(); stats.prompt_tokens = self.n_past; @@ -663,7 +663,7 @@ pub enum InferenceError { EndOfText, #[error("the user-specified callback returned an error")] /// The user-specified callback returned an error. - UserCallback(Box), + UserCallback(Box), } #[derive(Error, Debug)] @@ -885,7 +885,7 @@ pub enum InferenceFeedback { /// Adapt an [InferenceResponse] callback so that it can be used in a call to /// [InferenceSession::feed_prompt]. -pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>( +pub fn feed_prompt_callback<'a, E: std::error::Error + Send + Sync + 'static>( mut callback: impl FnMut(InferenceResponse) -> Result + 'a, ) -> impl FnMut(&[u8]) -> Result + 'a { let mut buffer = TokenUtf8Buffer::new(); @@ -897,8 +897,8 @@ pub fn feed_prompt_callback<'a, E: std::error::Error + 'static>( /// An [InferenceResponse] callback that will halt inference when a `stop_sequence` is generated. /// This callback is used in [InferenceSession::infer] in chat_mode. -pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>( - stop_sequence: String, +pub fn conversation_inference_callback<'a, E: std::error::Error + Send + Sync + 'static>( + stop_sequence: &'a str, mut callback: impl FnMut(String) + 'a, ) -> impl FnMut(InferenceResponse) -> Result + 'a { let mut stop_sequence_buf = String::new(); @@ -908,7 +908,7 @@ pub fn conversation_inference_callback<'a, E: std::error::Error + 'static>( let mut buf = stop_sequence_buf.clone(); buf.push_str(&token); - if buf.starts_with(&stop_sequence) { + if buf.starts_with(stop_sequence) { // We've generated the stop sequence, so we're done. // Note that this will contain the extra tokens that were generated after the stop sequence, // which may affect generation. This is non-ideal, but it's the best we can do without From 34a8c6804b4a3221753993be079dfdf4a9b3f041 Mon Sep 17 00:00:00 2001 From: Philpax Date: Thu, 13 Jul 2023 01:42:01 +0200 Subject: [PATCH 12/12] feat(cli): rewrite interactive... again The previous abstraction made it hard to reason about what each codepath would do. To resolve this, I've split the code up and now have separate functions entirely that share code. --- binaries/llm-cli/src/interactive.rs | 223 +++++++++++++++ binaries/llm-cli/src/main.rs | 269 ++---------------- binaries/llm-cli/src/util.rs | 10 + crates/llm/examples/vicuna-chat.rs | 2 +- utils/prompts/pygmalion-message.txt | 1 + .../{pygmalion.txt => pygmalion-prelude.txt} | 3 +- utils/prompts/vicuna-message.txt | 1 + .../{vicuna.txt => vicuna-prelude.txt} | 3 +- 8 files changed, 269 insertions(+), 243 deletions(-) create mode 100644 binaries/llm-cli/src/interactive.rs create mode 100644 binaries/llm-cli/src/util.rs create mode 100644 utils/prompts/pygmalion-message.txt rename utils/prompts/{pygmalion.txt => pygmalion-prelude.txt} (71%) create mode 100644 utils/prompts/vicuna-message.txt rename utils/prompts/{vicuna.txt => vicuna-prelude.txt} (76%) diff --git a/binaries/llm-cli/src/interactive.rs b/binaries/llm-cli/src/interactive.rs new file mode 100644 index 00000000..8e5c71c4 --- /dev/null +++ b/binaries/llm-cli/src/interactive.rs @@ -0,0 +1,223 @@ +use std::convert::Infallible; + +use color_eyre::eyre; +use rustyline::{ + error::ReadlineError, + history::DefaultHistory, + validate::{ValidationContext, ValidationResult, Validator}, + Cmd, Completer, Helper, Highlighter, Hinter, KeyCode, KeyEvent, Modifiers, +}; + +use crate::{ + cli_args::{Chat, Repl}, + snapshot, util, +}; + +pub fn repl( + Repl { + generate, + model_load, + prompt_file, + }: &Repl, +) -> eyre::Result<()> { + let (inference_session_config, parameters, model, mut rng) = + initialize_common_state(generate, model_load)?; + + let template = prompt_file.contents()?; + + let model = model.as_ref(); + let mut session = create_session(model, inference_session_config); + readline_loop(|raw_line| { + let line = raw_line.replace("\\\n", "\n"); + + let prompt = template + .as_deref() + .map(|template| util::process_prompt(template, &line)) + .unwrap_or(line); + feed_prompt_with_spinner(model, &mut session, ¶meters, prompt)?; + + session.infer::( + model, + &mut rng, + &llm::InferenceRequest { + prompt: "".into(), + parameters: ¶meters, + play_back_previous_tokens: false, + maximum_token_count: generate.num_predict, + }, + &mut Default::default(), + |r| { + if let llm::InferenceResponse::InferredToken(t) = r { + util::print_token(t); + } + Ok(llm::InferenceFeedback::Continue) + }, + )?; + + if !session_ends_with_newline(&session) { + println!(); + } + session = create_session(model, inference_session_config); + + Ok(()) + }) +} + +pub fn chat(args: &Chat) -> eyre::Result<()> { + let Chat { + model_load, + prelude_prompt_file, + generate, + .. + } = args; + + let (inference_session_config, parameters, model, mut rng) = + initialize_common_state(generate, model_load)?; + + let prelude_prompt = std::fs::read_to_string(prelude_prompt_file)?; + let message_prompt_prefix = args.message_prompt_prefix()?; + + let model = model.as_ref(); + let mut session = create_session(model, inference_session_config); + feed_prompt_with_spinner(model, &mut session, ¶meters, prelude_prompt)?; + + readline_loop(|raw_line| { + let prompt = { + let line = raw_line.replace("\\\n", "\n"); + let mut prompt = format!("{message_prompt_prefix}{line}"); + // Add a newline to the end of the prompt if it doesn't end with one + if !prompt.ends_with('\n') { + prompt.push('\n'); + } + prompt + }; + + session.infer::( + model, + &mut rng, + &llm::InferenceRequest { + prompt: (&prompt).into(), + parameters: ¶meters, + play_back_previous_tokens: false, + maximum_token_count: generate.num_predict, + }, + &mut Default::default(), + llm::conversation_inference_callback(&message_prompt_prefix, util::print_token), + )?; + + if !session_ends_with_newline(&session) { + println!(); + } + + Ok(()) + }) +} + +fn initialize_common_state( + generate: &crate::cli_args::Generate, + model_load: &crate::cli_args::ModelLoad, +) -> eyre::Result<( + llm::InferenceSessionConfig, + llm::InferenceParameters, + Box, + rand::rngs::StdRng, +)> { + let model = model_load.load(generate.use_gpu)?; + Ok(( + generate.inference_session_config(), + generate.inference_parameters(model.eot_token_id()), + model, + generate.rng(), + )) +} + +fn feed_prompt_with_spinner( + model: &dyn llm::Model, + session: &mut llm::InferenceSession, + parameters: &llm::InferenceParameters, + mut prompt: String, +) -> eyre::Result<()> { + // Add a newline to the beginning of the prompt if the last character in the session is not a newline + if !session_ends_with_newline(session) { + prompt.insert(0, '\n'); + } + + let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None); + let result = session.feed_prompt( + model, + parameters, + &prompt, + // OutputRequest + &mut Default::default(), + |_| Ok::<_, Infallible>(llm::InferenceFeedback::Continue), + ); + sp.clear(); + + Ok(result?) +} + +fn create_session( + model: &dyn llm::Model, + inference_session_config: llm::InferenceSessionConfig, +) -> llm::InferenceSession { + snapshot::read_or_create_session(model, None, None, inference_session_config).0 +} + +fn session_ends_with_newline(session: &llm::InferenceSession) -> bool { + session + .decoded_tokens() + .last() + .map(|t| *t == b'\n') + .unwrap_or(true) +} + +fn readline_loop(mut body: impl FnMut(String) -> eyre::Result<()>) -> eyre::Result<()> { + let mut rl = rustyline::Editor::::new()?; + rl.set_helper(Some(LineContinuationValidator)); + rl.bind_sequence(force_newline_event_seq(), Cmd::Newline); + + loop { + match rl.readline(">> ") { + Ok(raw_line) => { + if let Err(err) = body(raw_line) { + log::error!("{err}"); + break; + } + } + Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => { + break; + } + Err(err) => { + log::error!("{err}"); + break; + } + } + } + + Ok(()) +} + +#[cfg(not(windows))] +fn force_newline_event_seq() -> KeyEvent { + KeyEvent(KeyCode::Enter, Modifiers::ALT) +} + +// On Windows, `SHIFT+ENTER` is the key sequence for forcing a newline. This is +// because `ALT+ENTER` typically maximizes the window. +#[cfg(windows)] +fn force_newline_event_seq() -> KeyEvent { + KeyEvent(KeyCode::Enter, Modifiers::SHIFT) +} + +#[derive(Completer, Helper, Highlighter, Hinter, Debug, Clone, Copy)] +struct LineContinuationValidator; + +impl Validator for LineContinuationValidator { + fn validate(&self, ctx: &mut ValidationContext) -> rustyline::Result { + if ctx.input().ends_with('\\') { + Ok(ValidationResult::Incomplete) + } else { + Ok(ValidationResult::Valid(None)) + } + } +} diff --git a/binaries/llm-cli/src/main.rs b/binaries/llm-cli/src/main.rs index b1e13d90..1c55824f 100644 --- a/binaries/llm-cli/src/main.rs +++ b/binaries/llm-cli/src/main.rs @@ -1,27 +1,19 @@ use std::{ convert::Infallible, fs::File, - io::{BufReader, BufWriter, Write}, + io::{BufReader, BufWriter}, }; use clap::Parser; use cli_args::Args; -use color_eyre::eyre::{bail, Context, ContextCompat, Result}; -use llm::{ - conversation_inference_callback, InferenceError, InferenceFeedback, InferenceResponse, - InferenceSession, -}; -use rustyline::{ - error::ReadlineError, - history::DefaultHistory, - validate::{ValidationContext, ValidationResult, Validator}, - Cmd, Completer, Helper, Highlighter, Hinter, KeyCode, KeyEvent, Modifiers, -}; +use color_eyre::eyre::{self, Context, ContextCompat}; mod cli_args; +mod interactive; mod snapshot; +mod util; -fn main() -> Result<()> { +fn main() -> eyre::Result<()> { env_logger::builder() .filter_level(log::LevelFilter::Info) .parse_default_env() @@ -34,13 +26,13 @@ fn main() -> Result<()> { Args::Perplexity(args) => perplexity(&args), Args::Info(args) => info(&args), Args::PromptTokens(args) => prompt_tokens(&args), - Args::Repl(args) => repl(&args), - Args::Chat(args) => chat(&args), + Args::Repl(args) => interactive::repl(&args), + Args::Chat(args) => interactive::chat(&args), Args::Quantize(args) => quantize(&args), } } -fn infer(args: &cli_args::Infer) -> Result<()> { +fn infer(args: &cli_args::Infer) -> eyre::Result<()> { let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?; let inference_session_config = args.generate.inference_session_config(); let model = args.model_load.load(args.generate.use_gpu)?; @@ -65,18 +57,13 @@ fn infer(args: &cli_args::Infer) -> Result<()> { }, // OutputRequest &mut Default::default(), - |r| match &r { - InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => { - if matches!(&r, InferenceResponse::PromptToken(_)) && args.hide_prompt { - return Ok(InferenceFeedback::Continue); - } - - print!("{t}"); - std::io::stdout().flush().unwrap(); - - Ok(InferenceFeedback::Continue) + |r| { + match r { + llm::InferenceResponse::PromptToken(t) if !args.hide_prompt => util::print_token(t), + llm::InferenceResponse::InferredToken(t) => util::print_token(t), + _ => {} } - _ => Ok(InferenceFeedback::Continue), + Ok(llm::InferenceFeedback::Continue) }, ); println!(); @@ -89,13 +76,13 @@ fn infer(args: &cli_args::Infer) -> Result<()> { println!(); } } - Err(InferenceError::ContextFull) => { + Err(llm::InferenceError::ContextFull) => { log::warn!("Context window full, stopping inference.") } - Err(InferenceError::TokenizationFailed(err)) => { + Err(llm::InferenceError::TokenizationFailed(err)) => { log::error!("A tokenization-related failure occurred: {}", err); } - Err(InferenceError::UserCallback(_)) | Err(InferenceError::EndOfText) => { + Err(llm::InferenceError::UserCallback(_)) | Err(llm::InferenceError::EndOfText) => { unreachable!("cannot fail") } } @@ -108,7 +95,7 @@ fn infer(args: &cli_args::Infer) -> Result<()> { Ok(()) } -fn perplexity(args: &cli_args::Perplexity) -> Result<()> { +fn perplexity(args: &cli_args::Perplexity) -> eyre::Result<()> { let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?; let inference_session_config = args.generate.inference_session_config(); let model = args.model_load.load(args.generate.use_gpu)?; @@ -128,10 +115,10 @@ fn perplexity(args: &cli_args::Perplexity) -> Result<()> { Ok(()) } -fn info(args: &cli_args::Info) -> Result<()> { +fn info(args: &cli_args::Info) -> eyre::Result<()> { struct InfoVisitor<'a>(&'a cli_args::Info); - impl llm::ModelArchitectureVisitor> for InfoVisitor<'_> { - fn visit(&mut self) -> Result<()> { + impl llm::ModelArchitectureVisitor> for InfoVisitor<'_> { + fn visit(&mut self) -> eyre::Result<()> { let args = self.0; let model_path = &args.model_and_tokenizer.model_path; @@ -181,7 +168,7 @@ fn info(args: &cli_args::Info) -> Result<()> { .visit(&mut InfoVisitor(args)) } -fn prompt_tokens(args: &cli_args::PromptTokens) -> Result<()> { +fn prompt_tokens(args: &cli_args::PromptTokens) -> eyre::Result<()> { let prompt = load_prompt_file_with_prompt(&args.prompt_file, args.prompt.as_deref())?; let model = args.model_load.load(false)?; let toks = match model.tokenizer().tokenize(&prompt, false) { @@ -210,184 +197,12 @@ fn prompt_tokens(args: &cli_args::PromptTokens) -> Result<()> { Ok(()) } -#[cfg(not(windows))] -fn force_newline_event_seq() -> KeyEvent { - KeyEvent(KeyCode::Enter, Modifiers::ALT) -} - -// On Windows, `SHIFT+ENTER` is the key sequence for forcing a newline. This is -// because `ALT+ENTER` typically maximizes the window. -#[cfg(windows)] -fn force_newline_event_seq() -> KeyEvent { - KeyEvent(KeyCode::Enter, Modifiers::SHIFT) -} - -fn repl(args: &cli_args::Repl) -> Result<()> { - interactive( - &args.generate, - &args.model_load, - false, - None, - args.prompt_file.contents()?.as_deref(), - ) -} - -fn chat(args: &cli_args::Chat) -> Result<()> { - interactive( - &args.generate, - &args.model_load, - true, - Some(std::fs::read_to_string(&args.prelude_prompt_file)?.as_str()), - Some(&args.message_prompt_prefix()?), - ) -} - -fn interactive( - generate: &cli_args::Generate, - model_load: &cli_args::ModelLoad, - chat_mode: bool, - mut initial_prompt_template: Option<&str>, - message_prompt_prefix: Option<&str>, -) -> Result<()> { - let inference_session_config = generate.inference_session_config(); - let model = model_load.load(generate.use_gpu)?; - - let recreate_session = - || snapshot::read_or_create_session(model.as_ref(), None, None, inference_session_config).0; - let mut session = recreate_session(); - - let parameters = generate.inference_parameters(model.eot_token_id()); - let mut rng = generate.rng(); - - fn session_ends_with_newline(session: &InferenceSession) -> bool { - session - .decoded_tokens() - .last() - .map(|t| *t == b'\n') - .unwrap_or(true) - } - - let mut infer = |session: &mut InferenceSession, mut prompt: String| { - // Add a newline to the beginning of the prompt if the last character in the session is not a newline - if !session_ends_with_newline(session) { - prompt.insert(0, '\n'); - } - - let sp = spinoff::Spinner::new(spinoff::spinners::Dots2, "".to_string(), None); - if let Err(InferenceError::ContextFull) = session.feed_prompt( - model.as_ref(), - ¶meters, - &prompt, - // OutputRequest - &mut Default::default(), - |_| Ok::<_, Infallible>(InferenceFeedback::Continue), - ) { - log::error!("Prompt exceeds context window length.") - }; - sp.clear(); - - fn print_token(t: String) { - print!("{t}"); - std::io::stdout().flush().unwrap(); - } - - if chat_mode { - let stop_sequence = message_prompt_prefix.unwrap_or_default().to_owned(); - - session.infer::( - model.as_ref(), - &mut rng, - &llm::InferenceRequest { - prompt: "".into(), - parameters: ¶meters, - play_back_previous_tokens: false, - maximum_token_count: generate.num_predict, - }, - &mut Default::default(), - conversation_inference_callback(stop_sequence, print_token), - ) - } else { - session.infer::( - model.as_ref(), - &mut rng, - &llm::InferenceRequest { - prompt: "".into(), - parameters: ¶meters, - play_back_previous_tokens: false, - maximum_token_count: generate.num_predict, - }, - &mut Default::default(), - |r| match r { - InferenceResponse::PromptToken(t) | InferenceResponse::InferredToken(t) => { - print_token(t); - - Ok(InferenceFeedback::Continue) - } - _ => Ok(InferenceFeedback::Continue), - }, - ) - } - }; - - let mut rl = rustyline::Editor::::new()?; - rl.set_helper(Some(LineContinuationValidator)); - rl.bind_sequence(force_newline_event_seq(), Cmd::Newline); - - loop { - let readline = rl.readline(">> "); - match readline { - Ok(raw_line) => { - let line = raw_line.replace("\\\n", "\n"); - - // Use the initial prompt template for the first inference, - // and then switch to the message prompt prefix afterwards. - let mut prompt = initial_prompt_template - .take() - .map(|template| process_prompt(template, &line)) - .unwrap_or_else(|| { - message_prompt_prefix - .map(|prefix| format!("{}{}", prefix, line)) - .unwrap_or_else(|| line) - }); - - // Add a newline to the end of the prompt if it doesn't end with one in chat mode - if chat_mode && !prompt.ends_with('\n') { - prompt.push('\n'); - } - - if let Err(err) = infer(&mut session, prompt) { - log::error!("{err}"); - break; - } - - if !session_ends_with_newline(&session) { - println!(); - } - - // Reload session in REPL mode - if !chat_mode { - session = recreate_session(); - } - } - Err(ReadlineError::Eof) | Err(ReadlineError::Interrupted) => { - break; - } - Err(err) => { - log::error!("{err}"); - break; - } - } - } - - Ok(()) -} - -fn quantize(args: &cli_args::Quantize) -> Result<()> { +fn quantize(args: &cli_args::Quantize) -> eyre::Result<()> { use llm::QuantizeProgress; struct QuantizeVisitor<'a>(&'a cli_args::Quantize); - impl llm::ModelArchitectureVisitor> for QuantizeVisitor<'_> { - fn visit(&mut self) -> Result<()> { + impl llm::ModelArchitectureVisitor> for QuantizeVisitor<'_> { + fn visit(&mut self) -> eyre::Result<()> { let args = self.0; let mut source: BufReader = BufReader::new(std::fs::File::open(&args.source)?); @@ -445,33 +260,11 @@ fn quantize(args: &cli_args::Quantize) -> Result<()> { fn load_prompt_file_with_prompt( prompt_file: &cli_args::PromptFile, prompt: Option<&str>, -) -> Result { - Ok(if let Some(prompt_file) = prompt_file.contents()? { - if let Some(prompt) = prompt { - process_prompt(&prompt_file, prompt) - } else { - prompt_file - } - } else if let Some(prompt) = prompt { - prompt.to_owned() - } else { - bail!("No prompt or prompt file was provided. See --help"); +) -> eyre::Result { + Ok(match (prompt_file.contents()?, prompt) { + (Some(prompt_file), None) => prompt_file, + (None, Some(prompt)) => prompt.to_owned(), + (Some(prompt_file), Some(prompt)) => util::process_prompt(&prompt_file, prompt), + (None, None) => eyre::bail!("No prompt or prompt file was provided. See --help"), }) } - -#[derive(Completer, Helper, Highlighter, Hinter, Debug, Clone, Copy)] -struct LineContinuationValidator; - -impl Validator for LineContinuationValidator { - fn validate(&self, ctx: &mut ValidationContext) -> rustyline::Result { - if ctx.input().ends_with('\\') { - Ok(ValidationResult::Incomplete) - } else { - Ok(ValidationResult::Valid(None)) - } - } -} - -fn process_prompt(raw_prompt: &str, prompt: &str) -> String { - raw_prompt.replace("{{PROMPT}}", prompt) -} diff --git a/binaries/llm-cli/src/util.rs b/binaries/llm-cli/src/util.rs new file mode 100644 index 00000000..a925e005 --- /dev/null +++ b/binaries/llm-cli/src/util.rs @@ -0,0 +1,10 @@ +use std::io::Write; + +pub fn process_prompt(raw_prompt: &str, prompt: &str) -> String { + raw_prompt.replace("{{PROMPT}}", prompt) +} + +pub fn print_token(t: String) { + print!("{t}"); + std::io::stdout().flush().unwrap(); +} diff --git a/crates/llm/examples/vicuna-chat.rs b/crates/llm/examples/vicuna-chat.rs index ecf73de4..75702bb1 100644 --- a/crates/llm/examples/vicuna-chat.rs +++ b/crates/llm/examples/vicuna-chat.rs @@ -97,7 +97,7 @@ fn main() { maximum_token_count: None, }, &mut Default::default(), - conversation_inference_callback(format!("{character_name}:"), print_token), + conversation_inference_callback(&format!("{character_name}:"), print_token), ) .unwrap_or_else(|e| panic!("{e}")); diff --git a/utils/prompts/pygmalion-message.txt b/utils/prompts/pygmalion-message.txt new file mode 100644 index 00000000..6089c80f --- /dev/null +++ b/utils/prompts/pygmalion-message.txt @@ -0,0 +1 @@ +You: \ No newline at end of file diff --git a/utils/prompts/pygmalion.txt b/utils/prompts/pygmalion-prelude.txt similarity index 71% rename from utils/prompts/pygmalion.txt rename to utils/prompts/pygmalion-prelude.txt index 8b3cbadb..d6629953 100644 --- a/utils/prompts/pygmalion.txt +++ b/utils/prompts/pygmalion-prelude.txt @@ -1,4 +1,3 @@ Assistant's Persona: Assistant is a highly intelligent language model trained to comply with user requests. -Assistant: How may I help you? -You: {{PROMPT}} \ No newline at end of file +Assistant: How may I help you? \ No newline at end of file diff --git a/utils/prompts/vicuna-message.txt b/utils/prompts/vicuna-message.txt new file mode 100644 index 00000000..6b32a757 --- /dev/null +++ b/utils/prompts/vicuna-message.txt @@ -0,0 +1 @@ +User: \ No newline at end of file diff --git a/utils/prompts/vicuna.txt b/utils/prompts/vicuna-prelude.txt similarity index 76% rename from utils/prompts/vicuna.txt rename to utils/prompts/vicuna-prelude.txt index 1da54727..9c1bf219 100644 --- a/utils/prompts/vicuna.txt +++ b/utils/prompts/vicuna-prelude.txt @@ -1,4 +1,3 @@ A chat between a human ("User") and an AI assistant ("Assistant"). The assistant gives helpful, detailed, and polite answers to the human's questions. -Assistant: How may I help you? -User: {{PROMPT}} \ No newline at end of file +Assistant: How may I help you? \ No newline at end of file