Skip to content

Commit

Permalink
hf transformers chat template compat: tojson, strftime_now, raise_exc…
Browse files Browse the repository at this point in the history
…eption
  • Loading branch information
mmoskal committed Oct 28, 2024
1 parent f5b377a commit c08bad5
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 6 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion llgtrt/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ rand = "0.8.5"
llguidance_parser = { path = "../llguidance/parser" }
rayon = "1.10.0"
futures-core = "0.3.30"
minijinja = { version = "2.3.1", features = ["preserve_order", "json", "loop_controls", "loader"] }
minijinja = { version = "2.3.1", features = ["preserve_order", "loop_controls", "loader"] }
chrono = "0.4.38"
33 changes: 30 additions & 3 deletions llgtrt/src/chat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ use crate::{
tokenizer::TokenizerConfig,
};
use anyhow::anyhow;
use minijinja::Environment;
use minijinja::{value::Kwargs, Environment, Error, ErrorKind, Value};
use serde::{Deserialize, Serialize};
use serde_json::Value;

const DEFAULT_TEMPLATE: &str = r#"{{- bos_token }}
{%- for message in messages %}
Expand Down Expand Up @@ -49,7 +48,7 @@ fn date_string() -> String {
chrono::Utc::now().format("%e %B %Y").to_string()
}

fn remove_null(v: &mut Value) {
fn remove_null(v: &mut serde_json::Value) {
if let Some(map) = v.as_object_mut() {
for (_, v) in map.iter_mut() {
remove_null(v);
Expand All @@ -62,6 +61,28 @@ fn remove_null(v: &mut Value) {
}
}

/// Jinja `tojson` filter: renders `value` as pretty-printed JSON.
///
/// # Arguments
/// * `value` - the template value to serialize.
/// * `args` - keyword arguments; only `indent` (number of spaces per nesting
///   level, default `4`) is accepted.
///
/// # Returns
/// The JSON text wrapped with [`Value::from_safe_string`], i.e. marked as a
/// safe string.
///
/// # Errors
/// Returns an error when `indent` cannot be converted to `usize`, when an
/// unknown keyword argument is supplied, or when serialization fails.
fn tojson(value: Value, args: Kwargs) -> Result<Value, Error> {
    // Fetch as `Option<usize>` so that a missing `indent` falls back to the
    // default, while a wrongly-typed one surfaces as a conversion error
    // instead of being silently discarded.
    let indent = args.get::<Option<usize>>("indent")?.unwrap_or(4);
    args.assert_all_used()?;

    let indentation = " ".repeat(indent);
    let formatter = serde_json::ser::PrettyFormatter::with_indent(indentation.as_bytes());
    let mut out = Vec::<u8>::new();
    let mut serializer = serde_json::Serializer::with_formatter(&mut out, formatter);
    serde::Serialize::serialize(&value, &mut serializer).map_err(|err| {
        Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
    })?;

    // Checked conversion: avoids `unsafe` at the cost of one UTF-8 validation
    // pass over the already-produced buffer.
    let json = String::from_utf8(out).map_err(|err| {
        Error::new(ErrorKind::InvalidOperation, "cannot serialize to JSON").with_source(err)
    })?;
    Ok(Value::from_safe_string(json))
}

/// Formats the current UTC time using the strftime-style `format` string.
///
/// If `format` contains a specifier chrono rejects, `format` itself is
/// returned unchanged. Calling `to_string()` on chrono's formatter would
/// panic in that situation, and the format string comes from a chat
/// template, so a bad template must not bring the process down.
fn strftime_now(format: &str) -> String {
    use std::fmt::Write;

    let mut rendered = String::new();
    // `write!` reports a formatting failure as `Err` rather than panicking.
    match write!(rendered, "{}", chrono::Utc::now().format(format)) {
        Ok(()) => rendered,
        Err(_) => format.to_string(),
    }
}

impl ChatBuilder {
pub fn new(config: &TokenizerConfig) -> anyhow::Result<Self> {
let default_context = TemplateContext {
Expand All @@ -83,6 +104,12 @@ impl ChatBuilder {
// https://github.com/huggingface/transformers/blob/e50bf61decf741c6d59e4ba633b7392712673bda/src/transformers/utils/chat_template_utils.py#L423
env.set_lstrip_blocks(true);
env.set_trim_blocks(true);
env.add_function("raise_exception", |msg: String| {
let e = minijinja::Error::new(minijinja::ErrorKind::InvalidOperation, msg);
Err::<minijinja::Value, _>(e)
});
env.add_function("strftime_now", strftime_now);
env.add_filter("tojson", tojson);
let template = config
.chat_template
.clone()
Expand Down
14 changes: 13 additions & 1 deletion scripts/test-infer.sh
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,18 @@ curl -X POST "${TRT_API_BASE}chat/completions" \
"temperature": 0.7
}'
;;

exn)
curl -X POST "${TRT_API_BASE}chat/completions" \
-H "Content-Type: application/json" -v \
-d '{
"model": "model",
"messages": [
{"role": "assistant", "content": "What do you need?", "tool_calls": [{}, {}]}
],
"temperature": 0.7
}'
;;

tools)
curl -X POST "${TRT_API_BASE}chat/completions" \
Expand All @@ -108,7 +120,7 @@ curl -X POST "${TRT_API_BASE}chat/completions" \
"type": "function",
"function": {
"name": "weather",
"description": "Get the weather for a location",
"description": "Get the weather for a <location>",
"strict": true,
"parameters": {
"type": "object",
Expand Down

0 comments on commit c08bad5

Please sign in to comment.