diff --git a/Cargo.lock b/Cargo.lock index 089210b9c..d99860bae 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -46,7 +46,9 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", + "getrandom", "once_cell", + "serde", "version_check", "zerocopy", ] @@ -328,7 +330,7 @@ version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" dependencies = [ - "generic-array", + "generic-array 0.14.7", ] [[package]] @@ -731,7 +733,7 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ - "generic-array", + "generic-array 0.14.7", "typenum", ] @@ -1109,6 +1111,12 @@ dependencies = [ "windows-sys 0.59.0", ] +[[package]] +name = "fixedbitset-stack" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2d498da3487b4ea426e370276db9e93c29624b667855b415ccd01da983dd1237" + [[package]] name = "flate2" version = "1.0.34" @@ -1402,6 +1410,12 @@ dependencies = [ "seq-macro", ] +[[package]] +name = "general-sam" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b24ea6ac8cb305d7066d4c3e4587a5541f40124118367be09a10ea37898b8fc4" + [[package]] name = "generic-array" version = "0.14.7" @@ -1412,6 +1426,15 @@ dependencies = [ "version_check", ] +[[package]] +name = "generic-array" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "96512db27971c2c3eece70a1e106fbe6c87760234e31e8f7e5634912fe52794a" +dependencies = [ + "typenum", +] + [[package]] name = "getrandom" version = "0.2.15" @@ -1529,6 +1552,15 @@ dependencies = [ "ureq", ] +[[package]] +name = "html-escape" +version = "0.2.13" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d1ad449764d627e22bfd7cd5e8868264fc9236e07c752972b4080cd351cb476" +dependencies = [ + "utf8-width", +] + [[package]] name = "http" version = "1.1.0" @@ -1831,6 +1863,18 @@ version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jaggedarray" +version = "0.2.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d18d4d488e472f5e805edd8d7e710c255ff19511e3f5fbb0932932bee0e804d5" +dependencies = [ + "generic-array 1.1.0", + "num", + "tinyvec", + "typenum", +] + [[package]] name = "jpeg-decoder" version = "0.3.1" @@ -1846,6 +1890,61 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "kbnf" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8f8f3801253bf69215a17e0b509d813561ff113ab32506d0624391ca38e8511" +dependencies = [ + "ahash", + "displaydoc", + "fixedbitset-stack", + "general-sam", + "getrandom", + "jaggedarray", + "kbnf-regex-automata", + "kbnf-syntax", + "log", + "nom", + "nonmax", + "num", + "serde", + "string-interner", + "strum", + "thiserror", + "tinyvec", +] + +[[package]] +name = "kbnf-regex-automata" +version = "0.4.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "50a900e0a0e795744f6b18fe493500a1d773633972251f84f4be909edfb0d0f2" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax 0.8.5", +] + +[[package]] +name = "kbnf-syntax" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "abf70e6972a0f91801c1d8362527ea4a1449326a882a73186af16e47e2994c13" +dependencies = [ + "general-sam", + "kbnf-regex-automata", + "nom", + "parse-hyperlinks", + "regex-lite", + "regex-syntax 0.8.5", + "rustc-hash", + "serde", + "string-interner", + "thiserror", + "unescaper", +] + [[package]] name = "lazy_static" 
version = "1.5.0" @@ -2129,6 +2228,7 @@ name = "mistralrs-core" version = "0.3.1" dependencies = [ "accelerate-src", + "ahash", "akin", "anyhow", "as-any", @@ -2158,6 +2258,7 @@ dependencies = [ "indicatif", "intel-mkl-src", "itertools 0.13.0", + "kbnf", "lrtable", "minijinja", "minijinja-contrib", @@ -2351,6 +2452,12 @@ dependencies = [ "minimal-lexical", ] +[[package]] +name = "nonmax" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "610a5acd306ec67f907abe5567859a3c693fb9886eb1f012ab8f2a47bef3db51" + [[package]] name = "ntapi" version = "0.4.1" @@ -2370,6 +2477,20 @@ dependencies = [ "winapi", ] +[[package]] +name = "num" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" +dependencies = [ + "num-bigint", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", +] + [[package]] name = "num-bigint" version = "0.4.6" @@ -2399,6 +2520,17 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + [[package]] name = "num-rational" version = "0.4.2" @@ -2647,6 +2779,18 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "parse-hyperlinks" +version = "0.27.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92ac700657cc7a89620d1c28b310d594ff0491c11fd3dd1ae748ed5c5c8640e6" +dependencies = [ + "html-escape", + "nom", + "percent-encoding", + "thiserror", +] + [[package]] name = "parse-zoneinfo" version = "0.3.1" @@ -3133,6 +3277,12 @@ dependencies = [ "regex-syntax 0.8.5", ] +[[package]] +name = "regex-lite" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"53a49587ad06b26609c52e423de037e7f57f20d53535d66e08c695f347df952a" + [[package]] name = "regex-syntax" version = "0.6.29" @@ -3672,6 +3822,17 @@ dependencies = [ "regex", ] +[[package]] +name = "string-interner" +version = "0.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c6a0d765f5807e98a091107bae0a56ea3799f66a5de47b2c84c94a39c09974e" +dependencies = [ + "cfg-if", + "hashbrown", + "serde", +] + [[package]] name = "strsim" version = "0.10.0" @@ -4199,6 +4360,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eeba86d422ce181a719445e51872fa30f1f7413b62becb52e95ec91aa262d85c" +[[package]] +name = "unescaper" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c878a167baa8afd137494101a688ef8c67125089ff2249284bd2b5f9bfedb815" +dependencies = [ + "thiserror", +] + [[package]] name = "unicase" version = "2.7.0" @@ -4304,6 +4474,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "utf8-width" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86bd8d4e895da8537e5315b8254664e6b769c4ff3db18321b297a1e7004392e3" + [[package]] name = "utf8parse" version = "0.2.2" diff --git a/examples/server/grammar.py b/examples/server/grammar.py new file mode 100644 index 000000000..3d3bbe2b0 --- /dev/null +++ b/examples/server/grammar.py @@ -0,0 +1,99 @@ +from openai import OpenAI + +client = OpenAI(api_key="foobar", base_url="http://localhost:1234/v1/") + +JSON_KBNF = ''' +(* JSON Grammar *) + +(* JSON text must contain a single JSON value *) +start = value ; + +(* A JSON value can be an object, array, string, number, true, false, or null *) +value ::= object + | array + | string + | number + | "true" + | "false" + | "null" ; + +(* A JSON object is a collection of key/value pairs enclosed in curly braces *) +object ::= "{" [ members ] "}" ; +members ::= pair { "," pair } ; +pair ::= string ":" value 
; + +(* A JSON array is an ordered list of values enclosed in square brackets *) +array ::= "[" [ elements ] "]" ; +elements ::= value { "," value } ; + +(* A JSON string is a sequence of Unicode characters enclosed in double quotes *) +string ::= "\"" { character } "\"" ; +character ::= escape + | non_escape ; + +(* Escape sequences *) +escape ::= "\\" ( "\"" | "\\" | "/" | "b" | "f" | "n" | "r" | "t" | "u" hex hex hex hex ) ; +non_escape ::= ? any character except " or \ or control characters ? ; + +(* A JSON number is an integer or floating-point number *) +number ::= integer [ fraction ] [ exponent ] ; +integer ::= digit | "-" digit | "-" non_zero_digit { digit } | non_zero_digit { digit } ; +fraction ::= "." { digit } ; +exponent ::= ("e" | "E") [ "+" | "-" ] { digit } ; +digit ::= "0" | non_zero_digit ; +non_zero_digit ::= "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" ; + +(* Hexadecimal digits for unicode escape sequences *) +hex ::= digit | "A" | "B" | "C" | "D" | "E" | "F" ; + +''' + +EXPR_KBNF = ''' + +(* Grammar for Mathematical Expressions *) + +(* An expression can be a term or a sum/subtraction of terms *) +start ::= term { ("+" | "-") term } ; + +(* A term can be a factor or a product/division of factors *) +term ::= factor { ("*" | "/") factor } ; + +(* A factor can be a number, a variable, or a parenthesized expression *) +factor ::= number + | variable + | "(" start ")" ; + +(* A number is a sequence of digits, possibly with a decimal point *) +number ::= digit { digit } [ "." 
digit { digit } ] ; + +(* A variable is an identifier starting with a letter, possibly followed by letters or digits *) +variable ::= letter { letter | digit } ; + +(* Digits and letters *) +digit ::= "0" | "1" | "2" | "3" | "4" | "5" | "6" | "7" | "8" | "9" ; +letter ::= "a" | "b" | "c" | "d" | "e" | "f" | "g" | "h" | "i" | "j" + | "k" | "l" | "m" | "n" | "o" | "p" | "q" | "r" | "s" | "t" + | "u" | "v" | "w" | "x" | "y" | "z" + | "A" | "B" | "C" | "D" | "E" | "F" | "G" | "H" | "I" | "J" + | "K" | "L" | "M" | "N" | "O" | "P" | "Q" | "R" | "S" | "T" + | "U" | "V" | "W" | "X" | "Y" | "Z" ; + + +''' + +completion = client.chat.completions.create( + model="mistral", + messages=[ + { + "role": "user", + "content": "Write a mathematical expression.", + } + ], + max_tokens=256, + frequency_penalty=1.0, + top_p=0.1, + temperature=0, + extra_body={"grammar": {"type": "kbnf", "value": EXPR_KBNF}}, +) + +print(completion.choices[0].message.content) diff --git a/mistralrs-core/Cargo.toml b/mistralrs-core/Cargo.toml index 8a312edd2..d286759cf 100644 --- a/mistralrs-core/Cargo.toml +++ b/mistralrs-core/Cargo.toml @@ -78,6 +78,8 @@ regex = "1.10.6" safetensors = "0.4.5" serde_plain = "1.0.2" as-any = "0.3.1" +kbnf = "0.5.1" +ahash = "0.8.11" [features] pyo3_macros = ["pyo3"] diff --git a/mistralrs-core/src/engine/mod.rs b/mistralrs-core/src/engine/mod.rs index 04ab7e4bc..184dcfccc 100644 --- a/mistralrs-core/src/engine/mod.rs +++ b/mistralrs-core/src/engine/mod.rs @@ -1,3 +1,4 @@ +use anyhow::Context; use once_cell::sync::Lazy; use std::{ collections::HashMap, @@ -7,6 +8,7 @@ use std::{ }, time::{Instant, SystemTime, UNIX_EPOCH}, }; +use tokenizers::Tokenizer; use tokio::sync::{mpsc::Receiver, Mutex}; use crate::{ @@ -20,7 +22,7 @@ use crate::{ scheduler::{Scheduler, SchedulerOutput}, sequence::{SeqStepType, StopReason}, tools::{ToolCallingMatcher, ToolChoice}, - CompletionResponse, RequestMessage, Response, SchedulerConfig, DEBUG, + CompletionResponse, KbnfGrammar, 
RequestMessage, Response, SchedulerConfig, DEBUG, }; use rand::SeedableRng; use rand_isaac::Isaac64Rng; @@ -455,12 +457,20 @@ impl Engine { } } - fn build_sequence_recognizer(constraint: &Constraint) -> anyhow::Result<SequenceRecognizer> { + fn build_sequence_recognizer( + constraint: &Constraint, + tokenizer: Option<Arc<Tokenizer>>, + ) -> anyhow::Result<SequenceRecognizer> { let recognizer = match constraint { Constraint::Regex(rx) => { SequenceRecognizer::Regex(StackRecognizer::from(RecRx::from_rx(rx, None)?).into()) } Constraint::Yacc(cfg) => SequenceRecognizer::Cfg(CfgParser::from_yacc(cfg)?.into()), + Constraint::Kbnf(cfg) => SequenceRecognizer::Kbnf(KbnfGrammar::new( + cfg, + &*tokenizer + .context("Expected model to have a tokenizer, but using a KBNF grammar.")?, + )?), Constraint::None => SequenceRecognizer::None, }; Ok(recognizer) @@ -766,7 +776,10 @@ impl Engine { // Add sequences for response_index in 0..request.sampling_params.n_choices { - let recognizer = match Self::build_sequence_recognizer(&request.constraint) { + let recognizer = match Self::build_sequence_recognizer( + &request.constraint, + get_mut_arcmutex!(self.pipeline).tokenizer(), + ) { Ok(recognizer) => recognizer, Err(err) => { request diff --git a/mistralrs-core/src/kbnf.rs b/mistralrs-core/src/kbnf.rs new file mode 100644 index 000000000..edc0ec0da --- /dev/null +++ b/mistralrs-core/src/kbnf.rs @@ -0,0 +1,103 @@ +use ahash::AHashMap; +use anyhow::Result; +use candle_core::{DType, Device, Tensor}; +use kbnf::{ + engine_like::{AcceptTokenError, MaskLogitsError}, + AcceptTokenResult, Engine, EngineLike, Token, Vocabulary, +}; +use tokenizers::Tokenizer; + +pub struct KbnfGrammar { + engine: Engine, + vocab_size: usize, +} + +pub enum KbnfGrammarBias { + /// Token was accepted, it can be added to the sequence. No need for resampling. + Accepted, + /// Token was rejected. Resample with these new logits. + /// The token sampled with the bias can be added to the sequence.
+ Resample { new_logits: Tensor }, + /// Generation was finished, the token can be added to the sequence, + /// but no more generation is necessary. + FinishedGeneration, +} + +impl KbnfGrammar { + pub fn new(grammar: &str, tokenizer: &Tokenizer) -> Result<Self> { + let tokenizer_vocab = tokenizer.get_vocab(true); + let mut id_to_tok = AHashMap::new(); + let mut id_to_tok_str = AHashMap::new(); + for (tok_str, id) in tokenizer_vocab { + id_to_tok.insert(id, Token(tok_str.as_bytes().to_vec().into_boxed_slice())); + id_to_tok_str.insert(id, tok_str); + } + let vocab = Vocabulary::new(id_to_tok, id_to_tok_str)?; + Ok(Self { + engine: Engine::new(grammar, vocab)?, + vocab_size: tokenizer.get_vocab_size(true), + }) + } + + /// Compute the bias if this token were to be added. + /// If the token can be added, no resampling is needed; otherwise, biased logits are returned to resample from. + pub fn compute_bias_for( + &mut self, + tok: u32, + logits: &Tensor, + add_to_trie: bool, + ) -> candle_core::Result<KbnfGrammarBias> { + // Try to accept the new token + match self.engine.try_accept_new_token(tok) { + Ok(AcceptTokenResult::Ongoing) => { + // Token was accepted, no resampling needed + if add_to_trie { + self.engine.compute_allowed_token_ids(); + } + Ok(KbnfGrammarBias::Accepted) + } + Err(AcceptTokenError::Rejected) => { + if add_to_trie { + self.engine.compute_allowed_token_ids(); + } + let mut bias = vec![0f32; self.vocab_size]; + match self.engine.mask_logits(&mut bias) { + Ok(()) => { + let new_logits = (logits.to_device(&Device::Cpu)?.to_dtype(DType::F32)? + + Tensor::from_vec(bias, (self.vocab_size,), &Device::Cpu)?)?; + Ok(KbnfGrammarBias::Resample { new_logits }) + } + Err(MaskLogitsError::InvalidLogitsLength) => { + // This should really be unreachable.
+ candle_core::bail!("Invalid logits length {}", bias.len()) + } + } + } + Ok(AcceptTokenResult::Finished) | Err(AcceptTokenError::Finished) => { + Ok(KbnfGrammarBias::FinishedGeneration) + } + Err(AcceptTokenError::UnknownTokenID) => candle_core::bail!("Unknown token ID {tok}"), + } + } + + /// Add a token, also to the trie. + /// + /// This really should not fail as it should be called with the masked bias from `compute_bias_for`. + pub fn add_token(&mut self, tok: u32) -> candle_core::Result<()> { + // Try to accept the new token + match self.engine.try_accept_new_token(tok) { + Ok(AcceptTokenResult::Ongoing) => { + // Token was accepted, no resampling needed + self.engine.compute_allowed_token_ids(); + Ok(()) + } + Err(AcceptTokenError::Rejected) => { + candle_core::bail!("New token was rejected"); + } + Ok(AcceptTokenResult::Finished) | Err(AcceptTokenError::Finished) => { + candle_core::bail!("Generation was finished."); + } + Err(AcceptTokenError::UnknownTokenID) => candle_core::bail!("Unknown token ID {tok}"), + } + } +} diff --git a/mistralrs-core/src/lib.rs b/mistralrs-core/src/lib.rs index 59281e0e6..5ecaf1b87 100644 --- a/mistralrs-core/src/lib.rs +++ b/mistralrs-core/src/lib.rs @@ -26,6 +26,7 @@ mod aici; mod cuda; mod device_map; mod engine; +mod kbnf; mod lora; mod model_loader; mod ops; @@ -67,6 +68,7 @@ mod xlora_models; pub use amoe::{AnyMoeConfig, AnyMoeExpertType}; pub use device_map::{DeviceLayerMapMetadata, DeviceMapMetadata, LayerDeviceMapper}; pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER}; +pub use kbnf::{KbnfGrammar, KbnfGrammarBias}; pub use mistralrs_quant::IsqType; pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig}; pub use pipeline::{ diff --git a/mistralrs-core/src/pipeline/sampling.rs b/mistralrs-core/src/pipeline/sampling.rs index cab5c6f95..ba9878b58 100644 --- a/mistralrs-core/src/pipeline/sampling.rs +++ b/mistralrs-core/src/pipeline/sampling.rs @@ -8,6 +8,7 @@ use crate::{ 
prefix_cacher::PrefixCacheManager, sampler::Logprobs, sequence::{Sequence, SequenceRecognizer}, + KbnfGrammarBias, }; use super::Pipeline; @@ -324,76 +325,141 @@ pub async fn sample_sequence( )? }; - let bias_if_not_allowed = match &mut seq.recognizer { - SequenceRecognizer::Regex(ref mut rx) => { - get_bias_if_not_allowed!(seq.tok_trie, rx.as_mut(), first_lobprobs_response.token) - } - SequenceRecognizer::Cfg(ref mut cfg) => { - get_bias_if_not_allowed!(seq.tok_trie, cfg.as_mut(), first_lobprobs_response.token) - } - SequenceRecognizer::None => None, - }; - let second_logprobs_response = match bias_if_not_allowed { - Some(token_set) => { - let mut acc = vec![ - -f32::INFINITY; - seq.tok_trie - .as_ref() - .ok_or(candle_core::Error::Msg( - "TokTrie must be present in pipeline if bias is calculated".to_string() - ))? - .vocab_size() - ]; - token_set.apply_to(&mut acc); - let new_logits = (logits + Tensor::from_slice(&acc, acc.len(), &Device::Cpu)?)?; - - let ctx_clone = seq.get_toks().to_vec(); - let rng_clone = rng.clone(); - let sampler = seq.sampler(); - if use_async_pool { - tokio_rayon::spawn(move || { - sampler.sample( - new_logits, - &ctx_clone, - return_logprobs, - rng_clone, - sample_speculative, + match seq.recognizer { + SequenceRecognizer::Cfg(_) | SequenceRecognizer::Regex(_) => { + let bias_if_not_allowed = match &mut seq.recognizer { + SequenceRecognizer::Regex(ref mut rx) => { + get_bias_if_not_allowed!( + seq.tok_trie, + rx.as_mut(), + first_lobprobs_response.token ) - }) - .await? - } else { - sampler.sample( - new_logits, - &ctx_clone, - return_logprobs, - rng_clone, - sample_speculative, - )? 
+ } + SequenceRecognizer::Cfg(ref mut cfg) => { + get_bias_if_not_allowed!( + seq.tok_trie, + cfg.as_mut(), + first_lobprobs_response.token + ) + } + SequenceRecognizer::None | SequenceRecognizer::Kbnf(_) => None, + }; + let second_logprobs_response = match bias_if_not_allowed { + Some(token_set) => { + let mut acc = vec![ + -f32::INFINITY; + seq.tok_trie + .as_ref() + .ok_or(candle_core::Error::Msg( + "TokTrie must be present in pipeline if bias is calculated" + .to_string() + ))? + .vocab_size() + ]; + token_set.apply_to(&mut acc); + let new_logits = (logits + Tensor::from_slice(&acc, acc.len(), &Device::Cpu)?)?; + + let ctx_clone = seq.get_toks().to_vec(); + let rng_clone = rng.clone(); + let sampler = seq.sampler(); + if use_async_pool { + tokio_rayon::spawn(move || { + sampler.sample( + new_logits, + &ctx_clone, + return_logprobs, + rng_clone, + sample_speculative, + ) + }) + .await? + } else { + sampler.sample( + new_logits, + &ctx_clone, + return_logprobs, + rng_clone, + sample_speculative, + )? 
+ } + } + None => first_lobprobs_response, + }; + + if add_to_trie && seq.tok_trie.is_some() { + match seq.recognizer { + SequenceRecognizer::Regex(ref mut rx) => { + seq.tok_trie + .as_ref() + .unwrap() + .append_token(rx.as_mut(), second_logprobs_response.token) + .map_err(candle_core::Error::msg)?; + } + SequenceRecognizer::Cfg(ref mut cfg) => { + seq.tok_trie + .as_ref() + .unwrap() + .append_token(cfg.as_mut(), second_logprobs_response.token) + .map_err(candle_core::Error::msg)?; + } + SequenceRecognizer::None | SequenceRecognizer::Kbnf(_) => {} + } } + Ok(second_logprobs_response) } - None => first_lobprobs_response, - }; - if add_to_trie && seq.tok_trie.is_some() { - match seq.recognizer { - SequenceRecognizer::Regex(ref mut rx) => { - seq.tok_trie - .as_ref() - .unwrap() - .append_token(rx.as_mut(), second_logprobs_response.token) - .map_err(candle_core::Error::msg)?; - } - SequenceRecognizer::Cfg(ref mut cfg) => { - seq.tok_trie - .as_ref() - .unwrap() - .append_token(cfg.as_mut(), second_logprobs_response.token) - .map_err(candle_core::Error::msg)?; + SequenceRecognizer::None => Ok(first_lobprobs_response), + + SequenceRecognizer::Kbnf(_) => { + let bias = { + let SequenceRecognizer::Kbnf(ref mut kbnf) = seq.recognizer else { + unreachable!() + }; + kbnf.compute_bias_for(first_lobprobs_response.token, &logits, add_to_trie)? + }; + match bias { + // If the token was accepted, it's added to the kbnf engine + KbnfGrammarBias::Accepted => Ok(first_lobprobs_response), + KbnfGrammarBias::Resample { new_logits } => { + let ctx_clone = seq.get_toks().to_vec(); + let rng_clone = rng.clone(); + let sampler = seq.sampler(); + let second_sampled = if use_async_pool { + tokio_rayon::spawn(move || { + sampler.sample( + new_logits, + &ctx_clone, + return_logprobs, + rng_clone, + sample_speculative, + ) + }) + .await? + } else { + sampler.sample( + new_logits, + &ctx_clone, + return_logprobs, + rng_clone, + sample_speculative, + )? 
+ }; + + // Add to kbnf engine + if add_to_trie { + let SequenceRecognizer::Kbnf(ref mut kbnf) = seq.recognizer else { + unreachable!() + }; + kbnf.add_token(second_sampled.token)?; + } + Ok(second_sampled) + } + KbnfGrammarBias::FinishedGeneration => { + todo!() + } } - SequenceRecognizer::None => {} } } - Ok(second_logprobs_response) } #[derive(Clone)] diff --git a/mistralrs-core/src/pipeline/speculative.rs b/mistralrs-core/src/pipeline/speculative.rs index 19101a2e6..28912250c 100644 --- a/mistralrs-core/src/pipeline/speculative.rs +++ b/mistralrs-core/src/pipeline/speculative.rs @@ -601,6 +601,9 @@ impl Pipeline for SpeculativePipeline { .append_token(cfg.as_mut(), accepted.token) .map_err(candle_core::Error::msg)?; } + SequenceRecognizer::Kbnf(ref mut kbnf) => { + kbnf.add_token(accepted.token)?; + } SequenceRecognizer::None => {} } } diff --git a/mistralrs-core/src/request.rs b/mistralrs-core/src/request.rs index 4371afb1b..629dd3da9 100644 --- a/mistralrs-core/src/request.rs +++ b/mistralrs-core/src/request.rs @@ -17,6 +17,7 @@ use tokio::sync::mpsc::Sender; pub enum Constraint { Regex(String), Yacc(String), + Kbnf(String), None, } diff --git a/mistralrs-core/src/sequence.rs b/mistralrs-core/src/sequence.rs index c962ae90f..6b5fba9fc 100644 --- a/mistralrs-core/src/sequence.rs +++ b/mistralrs-core/src/sequence.rs @@ -15,7 +15,7 @@ use crate::{ response::CompletionChoice, tools::ToolCallingMatcher, CompletionChunkChoice, CompletionChunkResponse, CompletionResponse, ImageChoice, - ImageGenerationResponse, ImageGenerationResponseFormat, + ImageGenerationResponse, ImageGenerationResponseFormat, KbnfGrammar, }; use crate::{ get_mut_group, @@ -70,6 +70,7 @@ pub enum SequenceState { pub enum SequenceRecognizer { Regex(Box<StackRecognizer<u32, RecRx>>), Cfg(Box<CfgParser>), + Kbnf(KbnfGrammar), None, } diff --git a/mistralrs-pyo3/src/lib.rs b/mistralrs-pyo3/src/lib.rs index 0eae3a3c6..65eedf325 100644 --- a/mistralrs-pyo3/src/lib.rs +++ b/mistralrs-pyo3/src/lib.rs @@ -653,6 +653,13 @@ impl Runner { )); }
Constraint::Yacc(request.grammar.as_ref().unwrap().clone()) + } else if request.grammar_type == Some("kbnf".to_string()) { + if request.grammar.is_none() { + return Err(PyApiErr::from( + "Grammar type is specified but not grammar text", + )); + } + Constraint::Kbnf(request.grammar.as_ref().unwrap().clone()) } else if request.grammar_type.is_some() { return Err(PyApiErr::from( "Grammar type is specified but is not `regex` or `yacc`", @@ -915,6 +922,13 @@ impl Runner { )); } Constraint::Yacc(request.grammar.as_ref().unwrap().clone()) + } else if request.grammar_type == Some("kbnf".to_string()) { + if request.grammar.is_none() { + return Err(PyApiErr::from( + "Grammar type is specified but not grammar text", + )); + } + Constraint::Kbnf(request.grammar.as_ref().unwrap().clone()) } else if request.grammar_type.is_some() { return Err(PyApiErr::from( "Grammar type is specified but is not `regex` or `yacc`", diff --git a/mistralrs-server/src/chat_completion.rs b/mistralrs-server/src/chat_completion.rs index 5ba1745de..8a741db04 100644 --- a/mistralrs-server/src/chat_completion.rs +++ b/mistralrs-server/src/chat_completion.rs @@ -322,6 +322,7 @@ async fn parse_request( constraint: match oairequest.grammar { Some(Grammar::Yacc(yacc)) => Constraint::Yacc(yacc), Some(Grammar::Regex(regex)) => Constraint::Regex(regex), + Some(Grammar::Kbnf(kbnf)) => Constraint::Kbnf(kbnf), None => Constraint::None, }, adapters: oairequest.adapters, diff --git a/mistralrs-server/src/completions.rs b/mistralrs-server/src/completions.rs index c73a2e50b..4773804cb 100644 --- a/mistralrs-server/src/completions.rs +++ b/mistralrs-server/src/completions.rs @@ -202,6 +202,7 @@ fn parse_request( constraint: match oairequest.grammar { Some(Grammar::Yacc(yacc)) => Constraint::Yacc(yacc), Some(Grammar::Regex(regex)) => Constraint::Regex(regex), + Some(Grammar::Kbnf(kbnf)) => Constraint::Kbnf(kbnf), None => Constraint::None, }, adapters: oairequest.adapters, diff --git a/mistralrs-server/src/openai.rs 
b/mistralrs-server/src/openai.rs index 7671d2343..43689d60f 100644 --- a/mistralrs-server/src/openai.rs +++ b/mistralrs-server/src/openai.rs @@ -74,6 +74,8 @@ pub enum Grammar { Regex(String), #[serde(rename = "yacc")] Yacc(String), + #[serde(rename = "kbnf")] + Kbnf(String), } #[derive(Debug, Clone, Deserialize, Serialize, ToSchema)]