Skip to content

Commit

Permalink
use the new ParserFactory from llg with slicer optimization
Browse files Browse the repository at this point in the history
  • Loading branch information
mmoskal committed Dec 19, 2024
1 parent caaee2a commit ac512fd
Show file tree
Hide file tree
Showing 8 changed files with 49 additions and 96 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

21 changes: 20 additions & 1 deletion llgtrt/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use clap::Parser;
use llguidance::api::ParserLimits;
use serde::{Deserialize, Serialize};

use crate::{constraint_mgr::LlgConfig, tokenizer::TokenizerConfig};
use crate::tokenizer::TokenizerConfig;

const CONFIG_INFO: &str = include_str!("config_info.json");
pub fn config_info() -> serde_json::Value {
Expand Down Expand Up @@ -71,6 +72,24 @@ pub struct LlgTrtConfig {
pub llguidance: LlgConfig,
}

/// Settings for the llguidance library, stored in the `llguidance` field of `LlgTrtConfig`.
///
/// At server startup `limits` is copied into the `ParserFactory` limits and `log_level`
/// is passed to `ParserFactory::set_stderr_log_level`.
#[derive(Clone, Serialize, Deserialize, Debug)]
pub struct LlgConfig {
/// Override any of the parser limits.
/// The `Default` impl uses `ParserLimits::default()`.
pub limits: ParserLimits,

/// Log level which goes to stderr (the `Default` impl sets it to 1).
/// In-memory logs per-sequence are not governed by this value; they are
/// selected per request through the request's `llg_log_level` parameter.
pub log_level: u32,
}

impl Default for LlgConfig {
fn default() -> Self {
Self {
limits: ParserLimits::default(),
log_level: 1,
}
}
}

#[derive(Parser, Debug, Serialize, Deserialize)]
pub struct CliConfig {
/// Host to bind to
Expand Down
81 changes: 0 additions & 81 deletions llgtrt/src/constraint_mgr.rs

This file was deleted.

1 change: 0 additions & 1 deletion llgtrt/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,4 @@ pub mod startup;
pub mod state;
mod async_exec;
pub mod logging;
mod constraint_mgr;
pub mod jsonutil;
15 changes: 9 additions & 6 deletions llgtrt/src/routes/completions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use uuid::Uuid;

use crate::async_exec::{map_finish_reason, AsyncExecutor, StepResults};
use crate::chat::ChatParams;
use crate::constraint_mgr::ConstraintInit;
use crate::error::AppError;
use crate::routes::api_ext::{tools_to_schema, LlgLogLevel};
use crate::routes::openai::{JsonSchemaOptions, ResponseFormat, ToolChoice};
Expand Down Expand Up @@ -214,11 +213,15 @@ async fn mk_req_info(

let llg = if let Some(grm) = llg_grammar(params)? {
// println!("grammar: {}", serde_json::to_string(&grm).unwrap());
let mut llg = app_state.constraint_mgr.new_constraint(ConstraintInit {
grammar: grm,
is_chat,
log_level: params.llg_log_level,
})?;
let parser = app_state
.parser_factory
.create_parser_ext(grm, params.llg_log_level.to_log_level())?;

let mut llg = Constraint::new(parser);

if params.llg_log_level.has_json() {
llg.log_json_progress = true;
}

// temperature handled by logits processing - this has to be 1.0
// to avoid double-application of temperature
Expand Down
18 changes: 15 additions & 3 deletions llgtrt/src/startup.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ use axum::middleware::{self, Next};
use axum::response::Response;
use axum::routing::{get, post};
use axum::Router;
use llguidance::earley::SlicedBiasComputer;
use llguidance::ParserFactory;
use toktrie::InferenceCapabilities;
use trtllm_rs::{ClientReqId, ExecutorInit, RequestInit, RequestParams};

use crate::async_exec::AsyncExecutor;
use crate::config::{config_info, CliConfig, LlgTrtConfig};
use crate::constraint_mgr::ConstraintMgr;
use crate::jsonutil::json5_to_string;
use crate::state::AppState;
use crate::{jsonutil, routes};
Expand Down Expand Up @@ -146,7 +148,17 @@ pub async fn run_server(mut cli_config: CliConfig) -> anyhow::Result<()> {

// we only get here on rank 0

let constraint_mgr = ConstraintMgr::new(tok_env.clone(), tok_env.clone(), &config.llguidance)?;
let mut parser_factory = ParserFactory::new(
&tok_env,
InferenceCapabilities {
ff_tokens: false, // not supported yet
backtrack: false, // unlikely
..Default::default()
},
&SlicedBiasComputer::general_slices(),
);
*parser_factory.limits_mut() = config.llguidance.limits.clone();
parser_factory.set_stderr_log_level(config.llguidance.log_level);

if let Some(t) = config.tokenizer.json_start_token.as_ref() {
ensure!(
Expand All @@ -168,7 +180,7 @@ pub async fn run_server(mut cli_config: CliConfig) -> anyhow::Result<()> {
tok_env,
next_client_req_id: std::sync::atomic::AtomicUsize::new(1000),
chat_builder,
constraint_mgr,
parser_factory,
};

// warmup request
Expand Down
5 changes: 3 additions & 2 deletions llgtrt/src/state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use crate::{chat::ChatBuilder, constraint_mgr::ConstraintMgr};
use crate::chat::ChatBuilder;
use llguidance::ParserFactory;
use toktrie::{TokEnv, TokenId};

// there's generally an Arc() around this
Expand All @@ -10,7 +11,7 @@ pub struct AppState {
pub json_start_token_name: Option<String>,
pub next_client_req_id: std::sync::atomic::AtomicUsize,
pub chat_builder: ChatBuilder,
pub constraint_mgr: ConstraintMgr,
pub parser_factory: ParserFactory,
}

impl AppState {
Expand Down
2 changes: 1 addition & 1 deletion llguidance

0 comments on commit ac512fd

Please sign in to comment.