From 7290e34e3616d3759fd1a8e26fa55fe396b31afb Mon Sep 17 00:00:00 2001 From: Ian Armour Date: Wed, 6 Mar 2024 19:23:07 -0800 Subject: [PATCH 1/9] Adds support for GBNF grammar. --- crates/llama_cpp/src/grammar/mod.rs | 490 +++++++++++++++++++++++ crates/llama_cpp/src/lib.rs | 2 + crates/llama_cpp/src/standard_sampler.rs | 39 +- crates/llama_cpp_sys/build.rs | 1 + 4 files changed, 522 insertions(+), 10 deletions(-) create mode 100644 crates/llama_cpp/src/grammar/mod.rs diff --git a/crates/llama_cpp/src/grammar/mod.rs b/crates/llama_cpp/src/grammar/mod.rs new file mode 100644 index 0000000..3236672 --- /dev/null +++ b/crates/llama_cpp/src/grammar/mod.rs @@ -0,0 +1,490 @@ +//! The grammar module contains the grammar parser and the grammar struct. +//! +//! This allows creating a llama-cpp grammar. This is essentially a translation of the parser in +//! `common` to rust + +use std::collections::BTreeMap; +use std::fmt::{Debug, Formatter}; + +use llama_cpp_sys::{ + llama_grammar, llama_grammar_element, llama_gretype, llama_gretype_LLAMA_GRETYPE_ALT, llama_gretype_LLAMA_GRETYPE_CHAR, llama_gretype_LLAMA_GRETYPE_CHAR_ALT, llama_gretype_LLAMA_GRETYPE_CHAR_NOT, llama_gretype_LLAMA_GRETYPE_CHAR_RNG_UPPER, llama_gretype_LLAMA_GRETYPE_END, llama_gretype_LLAMA_GRETYPE_RULE_REF +}; +use std::ptr::NonNull; +use std::str::FromStr; +use tracing::error; + +/// Details of extraneous characters after a rule error. +#[derive(thiserror::Error, Debug)] +#[error("Extraneous chars after rule {name:?}: {chars:?}")] +pub struct ExtraneousCharsAfterRule { + /// The name of the rule being parsed + pub name: String, + /// the extraneous characters + pub chars: String, + /// the rest of the input, this is still to be parsed. + pub rest: String, +} + +/// There was an error parsing the grammar. +#[derive(thiserror::Error, Debug)] +#[allow(clippy::module_name_repetitions)] +pub enum GrammarParseError { + /// There was an unexpected end of input. 
+ #[error("Unexpected end of input")] + UnexpectedEndOfInput { + /// the stage of parsing that was being performed when we ran out of input. + parse_stage: &'static str, + }, + /// There was unexpected characters after a rule name but before "::=". There can only be whitespace. + #[error("Unexpected Chars after name {name:?} and before \"::=\": {chars}")] + UnexpectedCharsAfterName { + /// the name of the rule being parsed + name: String, + /// the unexpected characters + chars: String, + }, + /// There was no "::=" after a rule name. + #[error("Expected ::= after name {name:?}")] + ExpectedEqualsAfterName { + /// the name of the rule being parsed + name: String, + }, + /// There was no closing bracket in a nested rule. + #[error("Expected closing bracket in nested rule {name:?}")] + MissingClosingBracketInNestedRule { + /// the name of the rule being parsed + name: String, + }, + /// There was no rule before a postfix operator. + #[error("Missing rule before postfix operator in {name:?}")] + ExpectedRuleBeforePostfixOperator { + /// the name of the rule being parsed + name: String, + }, + /// There was an incorrect hex size. + #[error("Expected hex number with size {expected_size}, but number was {actual:?}")] + IncorrectHexSize { + /// the expected size of the hex number + expected_size: usize, + /// the actual hex number + actual: String, + }, + /// An unknown escape character was found. + #[error("Unknown escape {escape:?}")] + UnknownEscape { + /// the unknown character + escape: char, + }, + /// Failed to parse hex from a string. + #[error("Failed to parse hex from {string}: {error}")] + ParseHexError { + /// the error that occurred when parsing the hex + #[source] + error: std::num::ParseIntError, + /// the string that was being parsed + string: String, + }, + /// there was not space after the name + // todo: is this actually an error? 
+ #[error("Missing space after name in {rest:?}")] + MissingSpaceAfterName { + /// the rest of the input, this is still to be parsed. + rest: String, + }, + /// There was unexpected characters after the rule. + #[error("{0}")] + ExtraneousCharsAfterRule(ExtraneousCharsAfterRule), +} + +/// A grammar for llama-cpp. +#[allow(clippy::module_name_repetitions)] +pub struct LlamaGrammar { + parse: ParseState, + pub(crate) grammar: NonNull, +} + +impl Clone for LlamaGrammar { + fn clone(&self) -> Self { + let grammar = unsafe { llama_cpp_sys::llama_grammar_copy(self.grammar.as_ptr()) }; + Self { + parse: self.parse.clone(), + grammar: NonNull::new(grammar).expect("copied grammar should never be null"), + } + } +} + +unsafe impl Send for LlamaGrammar {} + +unsafe impl Sync for LlamaGrammar {} + +#[allow(clippy::module_name_repetitions)] +impl Debug for LlamaGrammar { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + f.debug_struct("LlamaGrammar") + .field("grammar", &self.grammar) + .field("parse", &self.parse) + .finish() + } +} + +#[derive(Debug, Clone, PartialEq)] +struct ParseState { + symbol_ids: BTreeMap, + rules: Vec>, +} + +impl ParseState { + fn new() -> Self { + Self { + symbol_ids: BTreeMap::new(), + rules: Vec::new(), + } + } + + fn get_symbol_id(&mut self, name: &str) -> u32 { + let next_id = + u32::try_from(self.symbol_ids.len()).expect("too many rules (must fit into u32)"); + let result = self.symbol_ids.entry(name.to_string()).or_insert(next_id); + *result + } + + fn generate_symbol_id(&mut self, name: &str) -> u32 { + let next_id = + u32::try_from(self.symbol_ids.len()).expect("too many rules (must fit into u32)"); + let generated_name = format!("{name}_{next_id}"); + let None = self.symbol_ids.insert(generated_name, next_id) else { + panic!("Failed to create unique name for {name}"); + }; + next_id + } + + fn parse_rule<'a>(&mut self, rest: &'a str) -> Result, GrammarParseError> { + let rest = Self::consume_whitespace_and_comments(rest, 
true); + if rest.is_empty() { + return Ok(None); + } + let (name, rest) = Self::parse_name(rest)?; + let rest = rest.trim_start(); + let rule_id = self.get_symbol_id(name); + + let (after_name, rest) = + rest.split_once("::=") + .ok_or_else(|| GrammarParseError::ExpectedEqualsAfterName { + name: name.to_string(), + })?; + + if !after_name.is_empty() { + return Err(GrammarParseError::UnexpectedCharsAfterName { + name: name.to_string(), + chars: after_name.to_string(), + }); + } + + let rest = self.parse_alternatives(name, rule_id, rest, false)?; + + let Some((after_rule, rest)) = rest.split_once('\n') else { + return Ok(None); + }; + + if !after_rule.chars().all(char::is_whitespace) { + return Err(GrammarParseError::ExtraneousCharsAfterRule( + ExtraneousCharsAfterRule { + name: name.to_string(), + chars: after_rule.to_string(), + rest: rest.to_string(), + }, + )); + } + + Ok(Some(rest)) + } + + fn consume_whitespace_and_comments(mut rest: &str, allow_newlines: bool) -> &str { + loop { + rest = rest.trim_start_matches( + |c: char| if allow_newlines { true } else { c != '\n' } && c.is_whitespace(), + ); + if rest.starts_with('#') { + rest = rest.split_once('\n').map_or("", |(_comment, rest)| rest); + } else { + break; + } + } + rest + } + + fn parse_alternatives<'a>( + &mut self, + name: &str, + id: u32, + rest: &'a str, + nested: bool, + ) -> Result<&'a str, GrammarParseError> { + let mut rule = Vec::new(); + let rest = self.parse_sequence(rest.trim_start(), name, &mut rule, nested)?; + let mut rest = Self::consume_whitespace_and_comments(rest, nested); + while rest.starts_with('|') { + rule.push(llama_grammar_element { + type_: llama_gretype_LLAMA_GRETYPE_ALT, + value: 0, + }); + rest = Self::consume_whitespace_and_comments(&rest[1..], true); + rest = self.parse_sequence(rest, name, &mut rule, nested)?; + } + rule.push(llama_grammar_element { + type_: llama_gretype_LLAMA_GRETYPE_END, + value: 0, + }); + self.add_rule(id, rule); + Ok(rest) + } + + fn add_rule(&mut 
self, id: u32, rule: Vec) { + let id = id as usize; + if self.rules.len() <= id { + self.rules.resize(id + 1, Vec::new()); + } + self.rules[id] = rule; + } + + #[allow(clippy::too_many_lines)] + fn parse_sequence<'a>( + &mut self, + mut rest: &'a str, + name: &str, + rule: &mut Vec, + nested: bool, + ) -> Result<&'a str, GrammarParseError> { + let mut last_sym_start = rule.len(); + while !rest.is_empty() { + let first_char = + rest.chars() + .next() + .ok_or(GrammarParseError::UnexpectedEndOfInput { + parse_stage: "sequence", + })?; + if first_char == '"' { + rest = &rest[1..]; + last_sym_start = rule.len(); + while !rest.starts_with('"') { + let (c, r) = Self::parse_char(rest)?; + rest = r; + rule.push(llama_grammar_element { + type_: llama_gretype_LLAMA_GRETYPE_CHAR, + value: c as _, + }); + } + rest = Self::consume_whitespace_and_comments(&rest[1..], nested); + } else if first_char == '[' { + rest = &rest[1..]; + let start_type = if rest.starts_with('^') { + rest = &rest[1..]; + llama_gretype_LLAMA_GRETYPE_CHAR_NOT + } else { + llama_gretype_LLAMA_GRETYPE_CHAR + }; + last_sym_start = rule.len(); + while !rest.starts_with(']') { + let (c, r) = Self::parse_char(rest)?; + rest = r; + let gre_type = if last_sym_start < rule.len() { + llama_gretype_LLAMA_GRETYPE_CHAR_ALT + } else { + start_type + }; + rule.push(llama_grammar_element { + type_: gre_type, + value: c as _, + }); + if rest.starts_with("-]") { + let (c, r) = Self::parse_char(rest)?; + rest = r; + rule.push(llama_grammar_element { + type_: llama_gretype_LLAMA_GRETYPE_CHAR_RNG_UPPER, + value: c as _, + }); + } + } + rest = Self::consume_whitespace_and_comments(&rest[1..], nested); + } else if first_char.is_alphabetic() { + let (name, r) = Self::parse_name(rest)?; + rest = Self::consume_whitespace_and_comments(r, nested); + let ref_rule_id = self.get_symbol_id(name); + last_sym_start = rule.len(); + rule.push(llama_grammar_element { + type_: llama_gretype_LLAMA_GRETYPE_RULE_REF, + value: ref_rule_id, + }); + 
} else if first_char == '(' { + rest = rest[1..].trim_start(); + let sub_rule_id = self.generate_symbol_id(name); + rest = self.parse_alternatives(name, sub_rule_id, rest, true)?; + last_sym_start = rule.len(); + rule.push(llama_grammar_element { + type_: llama_gretype_LLAMA_GRETYPE_RULE_REF, + value: sub_rule_id, + }); + if !rest.starts_with(')') { + return Err(GrammarParseError::MissingClosingBracketInNestedRule { + name: name.to_string(), + }); + } + rest = Self::consume_whitespace_and_comments(&rest[1..], nested); + } else if first_char == '*' || first_char == '+' || first_char == '?' { + if last_sym_start == rule.len() { + return Err(GrammarParseError::ExpectedRuleBeforePostfixOperator { + name: name.to_string(), + }); + } + let sub_rule_id = self.generate_symbol_id(name); + let mut sub_rule: Vec = + rule.iter().skip(last_sym_start).copied().collect(); + if rest.starts_with(['*', '+']) { + sub_rule.push(llama_grammar_element { + type_: llama_gretype_LLAMA_GRETYPE_RULE_REF, + value: sub_rule_id, + }); + } + sub_rule.push(llama_grammar_element { + type_: llama_gretype_LLAMA_GRETYPE_ALT, + value: 0, + }); + if rest.starts_with('+') { + sub_rule.extend(rule.iter().skip(last_sym_start).copied()); + } + sub_rule.push(llama_grammar_element { + type_: llama_gretype_LLAMA_GRETYPE_END, + value: 0, + }); + self.add_rule(sub_rule_id, sub_rule); + + rule.truncate(last_sym_start); + rule.push(llama_grammar_element { + type_: llama_gretype_LLAMA_GRETYPE_RULE_REF, + value: sub_rule_id, + }); + + rest = Self::consume_whitespace_and_comments(&rest[1..], nested); + } else { + break; + } + } + + Ok(rest) + } + + fn parse_hex(rest: &str, size: usize) -> Result<(llama_gretype, &str), GrammarParseError> { + if rest.len() < size { + return Err(GrammarParseError::IncorrectHexSize { + expected_size: size, + actual: rest.to_string(), + }); + } + + let (hex, rest) = rest.split_at(size); + let value = + u32::from_str_radix(hex, 16).map_err(|error| GrammarParseError::ParseHexError { + 
string: hex.to_string(), + error, + })?; + + Ok((value as llama_gretype, rest)) + } + + fn parse_char(rest: &str) -> Result<(llama_gretype, &str), GrammarParseError> { + if let Some(rest) = rest.strip_prefix('\\') { + let Some(escaped) = rest.chars().next() else { + return Err(GrammarParseError::UnexpectedEndOfInput { + parse_stage: "escape char", + }); + }; + let rest = &rest[escaped.len_utf8()..]; + match escaped { + 'x' => Self::parse_hex(rest, 2), + 'u' => Self::parse_hex(rest, 4), + 'U' => Self::parse_hex(rest, 8), + 't' => Ok((u32::from('\t') as llama_gretype, rest)), + 'r' => Ok((u32::from('\r') as llama_gretype, rest)), + 'n' => Ok((u32::from('\n') as llama_gretype, rest)), + '\\' => Ok((u32::from('\\') as llama_gretype, rest)), + '"' => Ok((u32::from('"') as llama_gretype, rest)), + '[' => Ok((u32::from('[') as llama_gretype, rest)), + ']' => Ok((u32::from(']') as llama_gretype, rest)), + c => Err(GrammarParseError::UnknownEscape { escape: c }), + } + } else if let Some(c) = rest.chars().next() { + Ok((u32::from(c) as llama_gretype, &rest[c.len_utf8()..])) + } else { + Err(GrammarParseError::UnexpectedEndOfInput { + parse_stage: "char", + }) + } + } + + fn parse_name(rest: &str) -> Result<(&str, &str), GrammarParseError> { + let name_end = rest + .find(|c: char| !c.is_alphanumeric() && c != '-' && c != '_') + .ok_or(GrammarParseError::MissingSpaceAfterName { + rest: rest.to_string(), + })?; + let name = &rest[..name_end]; + let rest = &rest[name_end..]; + Ok((name, rest)) + } +} + +/// An error that can occur creating a grammar from a string. +#[derive(thiserror::Error, Debug)] +pub enum LlamaGrammarFromStrError { + /// There was an error parsing the grammar. + #[error("Failed to parse grammar {0}")] + ParseError(#[from] GrammarParseError), + /// Llama-cpp returned null - this can occur for many reasons, but should ideally be caught on + /// the rust side beforehand. 
+ #[error("llama-cpp returned null")] + LlamaCppNullError, +} + +impl FromStr for ParseState { + type Err = GrammarParseError; + + fn from_str(s: &str) -> Result { + let mut parse_state = ParseState::new(); + let mut remaining = Some(s); + while let Some(str) = remaining { + remaining = parse_state.parse_rule(str)?; + } + Ok(parse_state) + } +} + +impl FromStr for LlamaGrammar { + type Err = LlamaGrammarFromStrError; + + fn from_str(s: &str) -> Result { + let mut parse_state = ParseState::from_str(s)?; + + let n_rules = parse_state.rules.len(); + let root_id = parse_state.get_symbol_id("root"); + let mut vec = parse_state + .rules + .iter_mut() + .map(|v| v.as_ptr()) + .collect::>(); + let rules = vec.as_mut_ptr(); + + let grammar = + unsafe { llama_cpp_sys::llama_grammar_init(rules, n_rules, root_id as usize) }; + + Ok(Self { + parse: parse_state, + grammar: NonNull::new(grammar).ok_or(LlamaGrammarFromStrError::LlamaCppNullError)?, + }) + } +} + +impl Drop for LlamaGrammar { + fn drop(&mut self) { + unsafe { llama_cpp_sys::llama_grammar_free(self.grammar.as_ptr()) } + } +} \ No newline at end of file diff --git a/crates/llama_cpp/src/lib.rs b/crates/llama_cpp/src/lib.rs index 7ff595a..8b4d345 100644 --- a/crates/llama_cpp/src/lib.rs +++ b/crates/llama_cpp/src/lib.rs @@ -77,6 +77,7 @@ #![warn(missing_docs)] +use grammar::LlamaGrammar; use llama_cpp_sys::{llama_context, llama_token_data_array}; use thiserror::Error; @@ -88,6 +89,7 @@ mod session; pub use model::*; pub use session::*; +pub mod grammar; /// The standard sampler implementation. 
pub mod standard_sampler; diff --git a/crates/llama_cpp/src/standard_sampler.rs b/crates/llama_cpp/src/standard_sampler.rs index 8294503..e96ebb3 100644 --- a/crates/llama_cpp/src/standard_sampler.rs +++ b/crates/llama_cpp/src/standard_sampler.rs @@ -1,13 +1,10 @@ use std::ptr::addr_of_mut; use llama_cpp_sys::{ - llama_context, llama_sample_entropy, llama_sample_min_p, llama_sample_repetition_penalties, - llama_sample_tail_free, llama_sample_temp, llama_sample_token, llama_sample_token_greedy, - llama_sample_token_mirostat, llama_sample_token_mirostat_v2, llama_sample_top_k, - llama_sample_top_p, llama_sample_typical, llama_token, llama_token_data_array, + llama_context, llama_sample_entropy, llama_sample_grammar, llama_grammar_accept_token, llama_sample_min_p, llama_sample_repetition_penalties, llama_sample_tail_free, llama_sample_temp, llama_sample_token, llama_sample_token_greedy, llama_sample_token_mirostat, llama_sample_token_mirostat_v2, llama_sample_top_k, llama_sample_top_p, llama_sample_typical, llama_token, llama_token_data_array }; -use crate::{Sampler, Token}; +use crate::{grammar::LlamaGrammar, Sampler, Token}; /// Functions which modify the probability distribution output by the model. /// @@ -19,7 +16,7 @@ use crate::{Sampler, Token}; /// 4. [`SamplerStage::TailFree`] /// 5. [`SamplerStage::Typical`] /// 6. [`SamplerStage::TopP`], [`SamplerStage::MinP`] -#[derive(Clone, Debug)] +#[derive(Debug)] #[non_exhaustive] pub enum SamplerStage { /// Divide the logits by this value. Ranges from 0 to 2. Lower values yield a more @@ -52,7 +49,6 @@ pub enum SamplerStage { /// temperature to approach `max_temp` more quickly at small entropies. exponent_val: f32, }, - /// Penalizes generating a token that is within the `last_n` tokens of context in various ways. 
RepetitionPenalty { /// Divide the token's logit by this value if they appear one or more time in the `last_n` @@ -228,10 +224,11 @@ impl TokenSelector { /// Selects a token after applying multiple [`SamplerStage`]'s to the /// probability distribution output by the model. -#[derive(Clone, Debug)] +#[derive(Debug)] pub struct StandardSampler { stages: Vec, min_keep: usize, + grammar: Option, token_selector: TokenSelector, } @@ -242,10 +239,15 @@ impl StandardSampler { /// /// Ensures that at least `min_keep` tokens remain after the /// [`SamplerStage`]'s are applied. - pub fn new_softmax(stages: Vec, min_keep: usize) -> StandardSampler { + pub fn new_softmax( + stages: Vec, + min_keep: usize, + grammar: Option, + ) -> StandardSampler { StandardSampler { stages, min_keep, + grammar: grammar, token_selector: TokenSelector::Softmax, } } @@ -256,6 +258,7 @@ impl StandardSampler { StandardSampler { stages: Vec::new(), min_keep: 0, + grammar: None, token_selector: TokenSelector::Greedy, } } @@ -272,6 +275,7 @@ impl StandardSampler { StandardSampler { stages, min_keep, + grammar: None, token_selector: TokenSelector::Mirostat { tau, eta, @@ -292,6 +296,7 @@ impl StandardSampler { StandardSampler { stages, min_keep, + grammar: None, token_selector: TokenSelector::MirostatV2 { tau, eta, @@ -316,6 +321,7 @@ impl Default for StandardSampler { SamplerStage::MinP(0.05), SamplerStage::Temperature(0.8), ], + grammar: None, min_keep: 1, token_selector: TokenSelector::Softmax, } @@ -330,12 +336,25 @@ impl Sampler for StandardSampler { tokens: &[Token], mut candidates_p: llama_token_data_array, ) -> Token { + let p_ptr = addr_of_mut!(candidates_p); let min_keep = self.min_keep.max(1); + // Note: We should sample grammar before applying other sampling stages. 
+ if let Some(grammar) = self.grammar.as_mut() { + unsafe { llama_sample_grammar(context, p_ptr, grammar.grammar.as_ptr()) }; + } + for stage in &self.stages { candidates_p = stage.apply(context, tokens, candidates_p, min_keep); } - self.token_selector.select(context, candidates_p) + let token = self.token_selector.select(context, candidates_p); + + // Note: We must accept the token into the grammar after sampling if a grammar is provided. + if let Some(grammar) = self.grammar.as_mut() { + unsafe { llama_grammar_accept_token(context, grammar.grammar.as_ptr(), token.0)} + } + + token } } diff --git a/crates/llama_cpp_sys/build.rs b/crates/llama_cpp_sys/build.rs index 201e0a6..9464d50 100644 --- a/crates/llama_cpp_sys/build.rs +++ b/crates/llama_cpp_sys/build.rs @@ -91,6 +91,7 @@ fn compile_bindings(out_path: &Path) { let bindings = bindgen::Builder::default() .header(LLAMA_PATH.join("ggml.h").to_string_lossy()) .header(LLAMA_PATH.join("llama.h").to_string_lossy()) + .derive_partialeq(true) .allowlist_type("ggml_.*") .allowlist_function("llama_.*") .allowlist_type("llama_.*") From e3f0b863a3eba78fb6c41e9c83e6b10af4f86465 Mon Sep 17 00:00:00 2001 From: Ian Armour Date: Fri, 8 Mar 2024 14:44:20 -0800 Subject: [PATCH 2/9] Added back in Clone derive for sampler structs --- crates/llama_cpp/src/standard_sampler.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/crates/llama_cpp/src/standard_sampler.rs b/crates/llama_cpp/src/standard_sampler.rs index e96ebb3..d0e6d56 100644 --- a/crates/llama_cpp/src/standard_sampler.rs +++ b/crates/llama_cpp/src/standard_sampler.rs @@ -16,7 +16,7 @@ use crate::{grammar::LlamaGrammar, Sampler, Token}; /// 4. [`SamplerStage::TailFree`] /// 5. [`SamplerStage::Typical`] /// 6. [`SamplerStage::TopP`], [`SamplerStage::MinP`] -#[derive(Debug)] +#[derive(Clone, Debug)] #[non_exhaustive] pub enum SamplerStage { /// Divide the logits by this value. Ranges from 0 to 2. 
Lower values yield a more @@ -224,7 +224,7 @@ impl TokenSelector { /// Selects a token after applying multiple [`SamplerStage`]'s to the /// probability distribution output by the model. -#[derive(Debug)] +#[derive(Clone, Debug)] pub struct StandardSampler { stages: Vec, min_keep: usize, From 7d1bc7c568d3c05a506afda733e0565df2f5e5d6 Mon Sep 17 00:00:00 2001 From: Ian Armour Date: Fri, 8 Mar 2024 14:45:54 -0800 Subject: [PATCH 3/9] Added missing newline --- crates/llama_cpp/src/grammar/mod.rs | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/crates/llama_cpp/src/grammar/mod.rs b/crates/llama_cpp/src/grammar/mod.rs index 3236672..7e1c19b 100644 --- a/crates/llama_cpp/src/grammar/mod.rs +++ b/crates/llama_cpp/src/grammar/mod.rs @@ -487,4 +487,4 @@ impl Drop for LlamaGrammar { fn drop(&mut self) { unsafe { llama_cpp_sys::llama_grammar_free(self.grammar.as_ptr()) } } -} \ No newline at end of file +} From c1d6aa131993ed1659e63c6185bec205e8f892e2 Mon Sep 17 00:00:00 2001 From: Ian Armour Date: Fri, 8 Mar 2024 14:56:21 -0800 Subject: [PATCH 4/9] Fixed formatting of changes. 
--- crates/llama_cpp/src/grammar/mod.rs | 5 ++++- crates/llama_cpp/src/standard_sampler.rs | 8 ++++++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/crates/llama_cpp/src/grammar/mod.rs b/crates/llama_cpp/src/grammar/mod.rs index 7e1c19b..6c82b65 100644 --- a/crates/llama_cpp/src/grammar/mod.rs +++ b/crates/llama_cpp/src/grammar/mod.rs @@ -7,7 +7,10 @@ use std::collections::BTreeMap; use std::fmt::{Debug, Formatter}; use llama_cpp_sys::{ - llama_grammar, llama_grammar_element, llama_gretype, llama_gretype_LLAMA_GRETYPE_ALT, llama_gretype_LLAMA_GRETYPE_CHAR, llama_gretype_LLAMA_GRETYPE_CHAR_ALT, llama_gretype_LLAMA_GRETYPE_CHAR_NOT, llama_gretype_LLAMA_GRETYPE_CHAR_RNG_UPPER, llama_gretype_LLAMA_GRETYPE_END, llama_gretype_LLAMA_GRETYPE_RULE_REF + llama_grammar, llama_grammar_element, llama_gretype, llama_gretype_LLAMA_GRETYPE_ALT, + llama_gretype_LLAMA_GRETYPE_CHAR, llama_gretype_LLAMA_GRETYPE_CHAR_ALT, + llama_gretype_LLAMA_GRETYPE_CHAR_NOT, llama_gretype_LLAMA_GRETYPE_CHAR_RNG_UPPER, + llama_gretype_LLAMA_GRETYPE_END, llama_gretype_LLAMA_GRETYPE_RULE_REF, }; use std::ptr::NonNull; use std::str::FromStr; diff --git a/crates/llama_cpp/src/standard_sampler.rs b/crates/llama_cpp/src/standard_sampler.rs index d0e6d56..712f126 100644 --- a/crates/llama_cpp/src/standard_sampler.rs +++ b/crates/llama_cpp/src/standard_sampler.rs @@ -1,7 +1,11 @@ use std::ptr::addr_of_mut; use llama_cpp_sys::{ - llama_context, llama_sample_entropy, llama_sample_grammar, llama_grammar_accept_token, llama_sample_min_p, llama_sample_repetition_penalties, llama_sample_tail_free, llama_sample_temp, llama_sample_token, llama_sample_token_greedy, llama_sample_token_mirostat, llama_sample_token_mirostat_v2, llama_sample_top_k, llama_sample_top_p, llama_sample_typical, llama_token, llama_token_data_array + llama_context, llama_grammar_accept_token, llama_sample_entropy, llama_sample_grammar, + llama_sample_min_p, llama_sample_repetition_penalties, llama_sample_tail_free, + 
llama_sample_temp, llama_sample_token, llama_sample_token_greedy, llama_sample_token_mirostat, + llama_sample_token_mirostat_v2, llama_sample_top_k, llama_sample_top_p, llama_sample_typical, + llama_token, llama_token_data_array, }; use crate::{grammar::LlamaGrammar, Sampler, Token}; @@ -352,7 +356,7 @@ impl Sampler for StandardSampler { // Note: We must accept the token into the grammar after sampling if a grammar is provided. if let Some(grammar) = self.grammar.as_mut() { - unsafe { llama_grammar_accept_token(context, grammar.grammar.as_ptr(), token.0)} + unsafe { llama_grammar_accept_token(context, grammar.grammar.as_ptr(), token.0) } } token From b6a8a069959d1f966eb566d705e8768e6eaff0f4 Mon Sep 17 00:00:00 2001 From: Ian Armour Date: Fri, 8 Mar 2024 15:14:59 -0800 Subject: [PATCH 5/9] Remove extra imports --- crates/llama_cpp/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/crates/llama_cpp/src/lib.rs b/crates/llama_cpp/src/lib.rs index 8b4d345..e05f6b1 100644 --- a/crates/llama_cpp/src/lib.rs +++ b/crates/llama_cpp/src/lib.rs @@ -77,7 +77,6 @@ #![warn(missing_docs)] -use grammar::LlamaGrammar; use llama_cpp_sys::{llama_context, llama_token_data_array}; use thiserror::Error; From 65238d355f4613117302b581e3931243b9d31e62 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Sun, 10 Mar 2024 22:41:19 -0500 Subject: [PATCH 6/9] Implement grammar sampling as a SamplerStage --- crates/llama_cpp/src/standard_sampler.rs | 83 +++++++++++++++++------- 1 file changed, 60 insertions(+), 23 deletions(-) diff --git a/crates/llama_cpp/src/standard_sampler.rs b/crates/llama_cpp/src/standard_sampler.rs index 712f126..e4026cf 100644 --- a/crates/llama_cpp/src/standard_sampler.rs +++ b/crates/llama_cpp/src/standard_sampler.rs @@ -103,6 +103,12 @@ pub enum SamplerStage { /// /// See: TailFree(f32), + + /// A stage that uses a [`LlamaGrammar`] to remove tokens that do not align with a given + /// grammar. 
+ /// + /// See [`GrammarStage`] and [`LlamaGrammar`] for more information. + Grammar(GrammarStage), } impl SamplerStage { @@ -112,7 +118,7 @@ impl SamplerStage { /// [`SamplerStage`]'s are applied. #[allow(clippy::not_unsafe_ptr_arg_deref)] pub fn apply( - &self, + &mut self, context: *mut llama_context, tokens: &[Token], mut candidates_p: llama_token_data_array, @@ -173,6 +179,9 @@ impl SamplerStage { SamplerStage::TailFree(z) => { llama_sample_tail_free(context, p_ptr, *z, min_keep); } + SamplerStage::Grammar(stage) => { + stage.apply(context, tokens, candidates_p, min_keep) + } } } @@ -180,6 +189,54 @@ impl SamplerStage { } } +/// Opaque internals for [`SamplerStage::Grammar`]. +#[derive(Clone, Debug)] +pub struct GrammarStage { + original_grammar: LlamaGrammar, + grammar: LlamaGrammar, + tokens: Vec, +} + +impl GrammarStage { + /// Creates a new [`GrammarStage`] from a [`LlamaGrammar`] + pub fn new(grammar: LlamaGrammar) -> Self { + Self { + original_grammar: grammar.clone(), + grammar, + tokens: Vec::new() + } + } + + /// Creates a new [`SamplerStage::Grammar`] from a [`LlamaGrammar`] + pub fn new_stage(grammar: LlamaGrammar) -> SamplerStage { + SamplerStage::Grammar(Self::new(grammar)) + } + + fn apply( + &mut self, + context: *mut llama_context, + tokens: &[Token], + mut candidates_p: llama_token_data_array, + _min_keep: usize, + ) { + let new_tokens = if let Some(suffix) = tokens.strip_prefix(self.tokens.as_slice()) { + suffix + } else { + self.tokens.clear(); + self.grammar = self.original_grammar.clone(); + tokens + }; + + for token in new_tokens { + unsafe { llama_grammar_accept_token(context, self.grammar.grammar.as_ptr(), token.0) } + } + self.tokens.extend_from_slice(new_tokens); + + let p_ptr = addr_of_mut!(candidates_p); + unsafe { llama_sample_grammar(context, p_ptr, self.grammar.grammar.as_ptr()) }; + } +} + /// Determines how the next token is selected from the distribution produced by /// the model and the [`SamplerStage`]'s. 
#[derive(Clone, Debug)] @@ -232,7 +289,6 @@ impl TokenSelector { pub struct StandardSampler { stages: Vec, min_keep: usize, - grammar: Option, token_selector: TokenSelector, } @@ -246,12 +302,10 @@ impl StandardSampler { pub fn new_softmax( stages: Vec, min_keep: usize, - grammar: Option, ) -> StandardSampler { StandardSampler { stages, min_keep, - grammar: grammar, token_selector: TokenSelector::Softmax, } } @@ -262,7 +316,6 @@ impl StandardSampler { StandardSampler { stages: Vec::new(), min_keep: 0, - grammar: None, token_selector: TokenSelector::Greedy, } } @@ -279,7 +332,6 @@ impl StandardSampler { StandardSampler { stages, min_keep, - grammar: None, token_selector: TokenSelector::Mirostat { tau, eta, @@ -300,7 +352,6 @@ impl StandardSampler { StandardSampler { stages, min_keep, - grammar: None, token_selector: TokenSelector::MirostatV2 { tau, eta, @@ -325,7 +376,6 @@ impl Default for StandardSampler { SamplerStage::MinP(0.05), SamplerStage::Temperature(0.8), ], - grammar: None, min_keep: 1, token_selector: TokenSelector::Softmax, } @@ -340,25 +390,12 @@ impl Sampler for StandardSampler { tokens: &[Token], mut candidates_p: llama_token_data_array, ) -> Token { - let p_ptr = addr_of_mut!(candidates_p); let min_keep = self.min_keep.max(1); - // Note: We should sample grammar before applying other sampling stages. - if let Some(grammar) = self.grammar.as_mut() { - unsafe { llama_sample_grammar(context, p_ptr, grammar.grammar.as_ptr()) }; - } - - for stage in &self.stages { + for stage in &mut self.stages { candidates_p = stage.apply(context, tokens, candidates_p, min_keep); } - let token = self.token_selector.select(context, candidates_p); - - // Note: We must accept the token into the grammar after sampling if a grammar is provided. 
- if let Some(grammar) = self.grammar.as_mut() { - unsafe { llama_grammar_accept_token(context, grammar.grammar.as_ptr(), token.0) } - } - - token + self.token_selector.select(context, candidates_p) } } From 62ee32613390da207da707f294885bf98303f650 Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Sun, 10 Mar 2024 23:04:44 -0500 Subject: [PATCH 7/9] Fix bugs; Simplify api --- crates/llama_cpp/src/session/params.rs | 2 +- crates/llama_cpp/src/standard_sampler.rs | 44 +++++++++--------------- 2 files changed, 18 insertions(+), 28 deletions(-) diff --git a/crates/llama_cpp/src/session/params.rs b/crates/llama_cpp/src/session/params.rs index 738615b..066ca1b 100644 --- a/crates/llama_cpp/src/session/params.rs +++ b/crates/llama_cpp/src/session/params.rs @@ -105,7 +105,7 @@ pub struct SessionParams { pub pooling: PoolingType, /// defragment the KV cache if holes/size > thold, < 0 disabled (default) - defrag_threshold: f32, + pub defrag_threshold: f32, } impl Default for SessionParams { diff --git a/crates/llama_cpp/src/standard_sampler.rs b/crates/llama_cpp/src/standard_sampler.rs index e4026cf..65defbf 100644 --- a/crates/llama_cpp/src/standard_sampler.rs +++ b/crates/llama_cpp/src/standard_sampler.rs @@ -105,13 +105,25 @@ pub enum SamplerStage { TailFree(f32), /// A stage that uses a [`LlamaGrammar`] to remove tokens that do not align with a given - /// grammar. + /// grammar. Since this stage has to handle mutable state, an instance of this stage should + /// only be used in one completion. /// /// See [`GrammarStage`] and [`LlamaGrammar`] for more information. Grammar(GrammarStage), } impl SamplerStage { + /// Creates a new [`SamplerStage::Grammar`] from a [`LlamaGrammar`]. + /// + /// `start_position` indicates the token position to begin applying the grammar at. [`None`] + /// indicates that the grammar begins at the end of context. 
+ pub fn from_grammar(grammar: LlamaGrammar, start_position: Option) -> Self { + SamplerStage::Grammar(GrammarStage { + grammar, + accepted_to: start_position, + }) + } + /// Applies this [`SamplerStage`] to the provided token data array. /// /// Ensures that at least `min_keep` tokens remain after the @@ -192,26 +204,11 @@ impl SamplerStage { /// Opaque internals for [`SamplerStage::Grammar`]. #[derive(Clone, Debug)] pub struct GrammarStage { - original_grammar: LlamaGrammar, grammar: LlamaGrammar, - tokens: Vec, + accepted_to: Option, } impl GrammarStage { - /// Creates a new [`GrammarStage`] from a [`LlamaGrammar`] - pub fn new(grammar: LlamaGrammar) -> Self { - Self { - original_grammar: grammar.clone(), - grammar, - tokens: Vec::new() - } - } - - /// Creates a new [`SamplerStage::Grammar`] from a [`LlamaGrammar`] - pub fn new_stage(grammar: LlamaGrammar) -> SamplerStage { - SamplerStage::Grammar(Self::new(grammar)) - } - fn apply( &mut self, context: *mut llama_context, @@ -219,18 +216,11 @@ impl GrammarStage { mut candidates_p: llama_token_data_array, _min_keep: usize, ) { - let new_tokens = if let Some(suffix) = tokens.strip_prefix(self.tokens.as_slice()) { - suffix - } else { - self.tokens.clear(); - self.grammar = self.original_grammar.clone(); - tokens - }; - - for token in new_tokens { + let accepted_to = self.accepted_to.unwrap_or(tokens.len()); + for token in &tokens[accepted_to..] 
{ unsafe { llama_grammar_accept_token(context, self.grammar.grammar.as_ptr(), token.0) } } - self.tokens.extend_from_slice(new_tokens); + self.accepted_to = Some(tokens.len()); let p_ptr = addr_of_mut!(candidates_p); unsafe { llama_sample_grammar(context, p_ptr, self.grammar.grammar.as_ptr()) }; From 7230cc63d11de8dd9760db7f6cfabe1b3e435ccd Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Sun, 10 Mar 2024 23:07:16 -0500 Subject: [PATCH 8/9] Update docs --- crates/llama_cpp/src/standard_sampler.rs | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/crates/llama_cpp/src/standard_sampler.rs b/crates/llama_cpp/src/standard_sampler.rs index 65defbf..52a82df 100644 --- a/crates/llama_cpp/src/standard_sampler.rs +++ b/crates/llama_cpp/src/standard_sampler.rs @@ -14,12 +14,13 @@ use crate::{grammar::LlamaGrammar, Sampler, Token}; /// /// Standard ordering for samplers (taken from [kobold.cpp](https://github.com/LostRuins/koboldcpp)): /// -/// 1. [`SamplerStage::RepetitionPenalty`] -/// 2. [`SamplerStage::Temperature`], [SamplerStage::DynamicTemperature] -/// 3. [`SamplerStage::TopK`] -/// 4. [`SamplerStage::TailFree`] -/// 5. [`SamplerStage::Typical`] -/// 6. [`SamplerStage::TopP`], [`SamplerStage::MinP`] +/// 1. [`SamplerStage::Grammar`] +/// 2. [`SamplerStage::RepetitionPenalty`] +/// 3. [`SamplerStage::Temperature`], [SamplerStage::DynamicTemperature] +/// 4. [`SamplerStage::TopK`] +/// 5. [`SamplerStage::TailFree`] +/// 6. [`SamplerStage::Typical`] +/// 7. 
[`SamplerStage::TopP`], [`SamplerStage::MinP`] #[derive(Clone, Debug)] #[non_exhaustive] pub enum SamplerStage { From e7e0f934ccfd43c68748bf4b7b43ec3cb7cc9c8c Mon Sep 17 00:00:00 2001 From: Nathan Koppel Date: Mon, 11 Mar 2024 12:00:55 -0500 Subject: [PATCH 9/9] Make changes more readable --- crates/llama_cpp/src/standard_sampler.rs | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/crates/llama_cpp/src/standard_sampler.rs b/crates/llama_cpp/src/standard_sampler.rs index 52a82df..e280ee7 100644 --- a/crates/llama_cpp/src/standard_sampler.rs +++ b/crates/llama_cpp/src/standard_sampler.rs @@ -121,7 +121,7 @@ impl SamplerStage { pub fn from_grammar(grammar: LlamaGrammar, start_position: Option<usize>) -> Self { SamplerStage::Grammar(GrammarStage { grammar, - accepted_to: start_position, + accepted_up_to: start_position, }) } @@ -193,7 +193,7 @@ impl SamplerStage { llama_sample_tail_free(context, p_ptr, *z, min_keep); } SamplerStage::Grammar(stage) => { - stage.apply(context, tokens, candidates_p, min_keep) + candidates_p = stage.apply(context, tokens, candidates_p, min_keep) } } } @@ -206,7 +206,7 @@ impl SamplerStage { #[derive(Clone, Debug)] pub struct GrammarStage { grammar: LlamaGrammar, - accepted_to: Option<usize>, + accepted_up_to: Option<usize>, } impl GrammarStage { @@ -216,15 +216,21 @@ impl GrammarStage { tokens: &[Token], mut candidates_p: llama_token_data_array, _min_keep: usize, - ) { - let accepted_to = self.accepted_to.unwrap_or(tokens.len()); - for token in &tokens[accepted_to..] { + ) -> llama_token_data_array { + // If `accepted_up_to` is `None`, assume that we should start at the end of context. + let accepted_up_to = self.accepted_up_to.unwrap_or(tokens.len()); + + // Accept all new tokens until the end of context. + for token in &tokens[accepted_up_to..] 
{ unsafe { llama_grammar_accept_token(context, self.grammar.grammar.as_ptr(), token.0) } } - self.accepted_to = Some(tokens.len()); + self.accepted_up_to = Some(tokens.len()); + // Apply grammar sampling to `candidates_p`. let p_ptr = addr_of_mut!(candidates_p); unsafe { llama_sample_grammar(context, p_ptr, self.grammar.grammar.as_ptr()) }; + + candidates_p } }