From ac512fd6a71ecef1de1fad8b40545484272bd7a6 Mon Sep 17 00:00:00 2001
From: Michal Moskal
Date: Thu, 19 Dec 2024 15:14:37 +0000
Subject: [PATCH] use the new ParserFactory from llg with slicer optimization

---
 Cargo.lock                       |  2 +-
 llgtrt/src/config.rs             | 21 ++++++++-
 llgtrt/src/constraint_mgr.rs     | 81 --------------------------------
 llgtrt/src/lib.rs                |  1 -
 llgtrt/src/routes/completions.rs | 15 +++---
 llgtrt/src/startup.rs            | 18 +++++--
 llgtrt/src/state.rs              |  5 +-
 llguidance                       |  2 +-
 8 files changed, 49 insertions(+), 96 deletions(-)
 delete mode 100644 llgtrt/src/constraint_mgr.rs

diff --git a/Cargo.lock b/Cargo.lock
index b8d52c5..99d4fc4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1,6 +1,6 @@
 # This file is automatically @generated by Cargo.
 # It is not intended for manual editing.
-version = 3
+version = 4
 
 [[package]]
 name = "addr2line"
diff --git a/llgtrt/src/config.rs b/llgtrt/src/config.rs
index 1d7bc87..0b97a2e 100644
--- a/llgtrt/src/config.rs
+++ b/llgtrt/src/config.rs
@@ -1,7 +1,8 @@
 use clap::Parser;
+use llguidance::api::ParserLimits;
 use serde::{Deserialize, Serialize};
-use crate::{constraint_mgr::LlgConfig, tokenizer::TokenizerConfig};
+use crate::tokenizer::TokenizerConfig;
 
 const CONFIG_INFO: &str = include_str!("config_info.json");
 
 pub fn config_info() -> serde_json::Value {
@@ -71,6 +72,24 @@ pub struct LlgTrtConfig {
     pub llguidance: LlgConfig,
 }
 
+#[derive(Clone, Serialize, Deserialize, Debug)]
+pub struct LlgConfig {
+    /// Override any of the parser limits.
+    pub limits: ParserLimits,
+
+    /// Log level which goes to stderr. In-memory logs per-sequence are managed by ConstraintInit.log_level.
+    pub log_level: u32,
+}
+
+impl Default for LlgConfig {
+    fn default() -> Self {
+        Self {
+            limits: ParserLimits::default(),
+            log_level: 1,
+        }
+    }
+}
+
 #[derive(Parser, Debug, Serialize, Deserialize)]
 pub struct CliConfig {
     /// Host to bind to
diff --git a/llgtrt/src/constraint_mgr.rs b/llgtrt/src/constraint_mgr.rs
deleted file mode 100644
index 8aa971c..0000000
--- a/llgtrt/src/constraint_mgr.rs
+++ /dev/null
@@ -1,81 +0,0 @@
-use crate::routes::api_ext::LlgLogLevel;
-use anyhow::Result;
-use llguidance::{
-    api::{ParserLimits, TopLevelGrammar},
-    Constraint, Logger, TokenParser,
-};
-use serde::{Deserialize, Serialize};
-use toktrie::{InferenceCapabilities, TokEnv};
-
-#[derive(Clone, Serialize, Deserialize, Debug)]
-pub struct LlgConfig {
-    /// Override any of the parser limits.
-    pub limits: ParserLimits,
-
-    /// Log level which goes to stderr. In-memory logs per-sequence are managed by ConstraintInit.log_level.
-    pub log_level: u32,
-}
-
-impl Default for LlgConfig {
-    fn default() -> Self {
-        Self {
-            limits: ParserLimits::default(),
-            log_level: 1,
-        }
-    }
-}
-
-pub struct ConstraintInit {
-    pub grammar: TopLevelGrammar,
-    pub is_chat: bool,
-    pub log_level: LlgLogLevel,
-}
-
-pub struct ConstraintMgr {
-    tok_env: TokEnv,
-    chat_tok_env: TokEnv,
-    inference_caps: InferenceCapabilities,
-    parser_limits: ParserLimits,
-    log_stderr_level: u32,
-}
-
-impl ConstraintMgr {
-    pub fn new(tok_env: TokEnv, chat_tok_env: TokEnv, config: &LlgConfig) -> Result<Self> {
-        Ok(ConstraintMgr {
-            tok_env,
-            chat_tok_env,
-            inference_caps: InferenceCapabilities {
-                ff_tokens: false, // not supported yet
-                backtrack: false, // unlikely
-                ..Default::default()
-            },
-            parser_limits: config.limits.clone(),
-            log_stderr_level: config.log_level,
-        })
-    }
-
-    pub fn new_constraint(&self, init: ConstraintInit) -> Result<Constraint> {
-        let parser = TokenParser::from_llguidance_json(
-            if init.is_chat {
-                self.chat_tok_env.clone()
-            } else {
-                self.tok_env.clone()
-            },
-            init.grammar,
-            Logger::new(init.log_level.to_log_level(), self.log_stderr_level),
-            self.inference_caps.clone(),
-            self.parser_limits.clone(),
-            vec![],
-        )?;
-        let mut constraint = Constraint::new(parser);
-        if init.log_level.has_json() {
-            constraint.log_json_progress = true;
-        }
-        Ok(constraint)
-    }
-
-    #[allow(dead_code)]
-    pub fn tok_trie(&self) -> &toktrie::TokTrie {
-        self.tok_env.tok_trie()
-    }
-}
diff --git a/llgtrt/src/lib.rs b/llgtrt/src/lib.rs
index 016c126..a9ebfac 100644
--- a/llgtrt/src/lib.rs
+++ b/llgtrt/src/lib.rs
@@ -7,5 +7,4 @@ pub mod startup;
 pub mod state;
 mod async_exec;
 pub mod logging;
-mod constraint_mgr;
 pub mod jsonutil;
\ No newline at end of file
diff --git a/llgtrt/src/routes/completions.rs b/llgtrt/src/routes/completions.rs
index a56f603..a18d0f8 100644
--- a/llgtrt/src/routes/completions.rs
+++ b/llgtrt/src/routes/completions.rs
@@ -19,7 +19,6 @@ use uuid::Uuid;
 
 use crate::async_exec::{map_finish_reason, AsyncExecutor, StepResults};
 use crate::chat::ChatParams;
-use crate::constraint_mgr::ConstraintInit;
 use crate::error::AppError;
 use crate::routes::api_ext::{tools_to_schema, LlgLogLevel};
 use crate::routes::openai::{JsonSchemaOptions, ResponseFormat, ToolChoice};
@@ -214,11 +213,15 @@ async fn mk_req_info(
     let llg = if let Some(grm) = llg_grammar(params)?
{ // println!("grammar: {}", serde_json::to_string(&grm).unwrap()); - let mut llg = app_state.constraint_mgr.new_constraint(ConstraintInit { - grammar: grm, - is_chat, - log_level: params.llg_log_level, - })?; + let parser = app_state + .parser_factory + .create_parser_ext(grm, params.llg_log_level.to_log_level())?; + + let mut llg = Constraint::new(parser); + + if params.llg_log_level.has_json() { + llg.log_json_progress = true; + } // temperature handled by logits processing - this has to be 1.0 // to avoid double-application of temperature diff --git a/llgtrt/src/startup.rs b/llgtrt/src/startup.rs index 810f67d..82a75b1 100644 --- a/llgtrt/src/startup.rs +++ b/llgtrt/src/startup.rs @@ -7,11 +7,13 @@ use axum::middleware::{self, Next}; use axum::response::Response; use axum::routing::{get, post}; use axum::Router; +use llguidance::earley::SlicedBiasComputer; +use llguidance::ParserFactory; +use toktrie::InferenceCapabilities; use trtllm_rs::{ClientReqId, ExecutorInit, RequestInit, RequestParams}; use crate::async_exec::AsyncExecutor; use crate::config::{config_info, CliConfig, LlgTrtConfig}; -use crate::constraint_mgr::ConstraintMgr; use crate::jsonutil::json5_to_string; use crate::state::AppState; use crate::{jsonutil, routes}; @@ -146,7 +148,17 @@ pub async fn run_server(mut cli_config: CliConfig) -> anyhow::Result<()> { // we only get here on rank 0 - let constraint_mgr = ConstraintMgr::new(tok_env.clone(), tok_env.clone(), &config.llguidance)?; + let mut parser_factory = ParserFactory::new( + &tok_env, + InferenceCapabilities { + ff_tokens: false, // not supported yet + backtrack: false, // unlikely + ..Default::default() + }, + &SlicedBiasComputer::general_slices(), + ); + *parser_factory.limits_mut() = config.llguidance.limits.clone(); + parser_factory.set_stderr_log_level(config.llguidance.log_level); if let Some(t) = config.tokenizer.json_start_token.as_ref() { ensure!( @@ -168,7 +180,7 @@ pub async fn run_server(mut cli_config: CliConfig) -> anyhow::Result<()> { tok_env, next_client_req_id: std::sync::atomic::AtomicUsize::new(1000), chat_builder, - constraint_mgr, + parser_factory, }; // warmup request diff --git a/llgtrt/src/state.rs b/llgtrt/src/state.rs index 26a3bcd..3105429 100644 --- a/llgtrt/src/state.rs +++ b/llgtrt/src/state.rs @@ -1,4 +1,5 @@ -use crate::{chat::ChatBuilder, constraint_mgr::ConstraintMgr}; +use crate::chat::ChatBuilder; +use llguidance::ParserFactory; use toktrie::{TokEnv, TokenId}; // there's generally an Arc() around this @@ -10,7 +11,7 @@ pub struct AppState { pub json_start_token_name: Option, pub next_client_req_id: std::sync::atomic::AtomicUsize, pub chat_builder: ChatBuilder, - pub constraint_mgr: ConstraintMgr, + pub parser_factory: ParserFactory, } impl AppState { diff --git a/llguidance b/llguidance index adfac17..76c01cd 160000 --- a/llguidance +++ b/llguidance @@ -1 +1 @@ -Subproject commit adfac174fe40471bb719eab06db9e6f64a43e435 +Subproject commit 76c01cd9038e3bc7ceb8196c08a7b290ee4bc078