diff --git a/Cargo.lock b/Cargo.lock index a627b5b..0d0fdde 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1424,7 +1424,6 @@ dependencies = [ "memo-map", "self_cell", "serde", - "serde_json", ] [[package]] diff --git a/llgtrt/Cargo.toml b/llgtrt/Cargo.toml index db7a2c2..3b7ea37 100644 --- a/llgtrt/Cargo.toml +++ b/llgtrt/Cargo.toml @@ -21,5 +21,5 @@ rand = "0.8.5" llguidance_parser = { path = "../llguidance/parser" } rayon = "1.10.0" futures-core = "0.3.30" -minijinja = { version = "2.3.1", features = ["preserve_order", "json", "loop_controls", "loader"] } +minijinja = { version = "2.3.1", features = ["preserve_order", "loop_controls", "loader"] } chrono = "0.4.38" diff --git a/llgtrt/src/chat.rs b/llgtrt/src/chat.rs index e3a8eea..7c90278 100644 --- a/llgtrt/src/chat.rs +++ b/llgtrt/src/chat.rs @@ -3,9 +3,8 @@ use crate::{ tokenizer::TokenizerConfig, }; use anyhow::anyhow; -use minijinja::Environment; +use minijinja::{value::Kwargs, Environment, Error, ErrorKind, Value}; use serde::{Deserialize, Serialize}; -use serde_json::Value; const DEFAULT_TEMPLATE: &str = r#"{{- bos_token }} {%- for message in messages %} @@ -49,7 +48,7 @@ fn date_string() -> String { chrono::Utc::now().format("%e %B %Y").to_string() } -fn remove_null(v: &mut Value) { +fn remove_null(v: &mut serde_json::Value) { if let Some(map) = v.as_object_mut() { for (_, v) in map.iter_mut() { remove_null(v); @@ -62,6 +61,28 @@ fn remove_null(v: &mut Value) { } } +fn tojson(value: Value, args: Kwargs) -> Result { + let indent = match args.get::("indent") { + Ok(val) => val, + Err(_) => 4, + }; + args.assert_all_used()?; + let mut out = Vec::::new(); + let indentation = " ".repeat(indent); + let formatter = serde_json::ser::PrettyFormatter::with_indent(indentation.as_bytes()); + let mut s = serde_json::Serializer::with_formatter(&mut out, formatter); + let v = serde::Serialize::serialize(&value, &mut s) + .map(|_| unsafe { String::from_utf8_unchecked(out) }) + .map_err(|err| { + Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err) + })?; + Ok(Value::from_safe_string(v)) +} + +fn strftime_now(format: &str) -> String { + chrono::Utc::now().format(format).to_string() +} + impl ChatBuilder { pub fn new(config: &TokenizerConfig) -> anyhow::Result { let default_context = TemplateContext { @@ -83,6 +104,12 @@ impl ChatBuilder { // https://github.com/huggingface/transformers/blob/e50bf61decf741c6d59e4ba633b7392712673bda/src/transformers/utils/chat_template_utils.py#L423 env.set_lstrip_blocks(true); env.set_trim_blocks(true); + env.add_function("raise_exception", |msg: String| { + let e = minijinja::Error::new(minijinja::ErrorKind::InvalidOperation, msg); + Err::(e) + }); + env.add_function("strftime_now", strftime_now); + env.add_filter("tojson", tojson); let template = config .chat_template .clone() diff --git a/scripts/test-infer.sh b/scripts/test-infer.sh index 1d3a14b..92761c3 100755 --- a/scripts/test-infer.sh +++ b/scripts/test-infer.sh @@ -86,6 +86,18 @@ curl -X POST "${TRT_API_BASE}chat/completions" \ "temperature": 0.7 }' ;; + + exn) + curl -X POST "${TRT_API_BASE}chat/completions" \ + -H "Content-Type: application/json" -v \ + -d '{ + "model": "model", + "messages": [ + {"role": "assistant", "content": "What do you need?", "tool_calls": [{}, {}]} + ], + "temperature": 0.7 + }' + ;; tools) curl -X POST "${TRT_API_BASE}chat/completions" \ @@ -108,7 +120,7 @@ curl -X POST "${TRT_API_BASE}chat/completions" \ "type": "function", "function": { "name": "weather", - "description": "Get the weather for a location", + "description": "Get the weather for a ", "strict": true, "parameters": { "type": "object",