From dd689247e49f5acbfe9b59f624c130b94094fef4 Mon Sep 17 00:00:00 2001 From: Jeffrey Tang <810895+jeffreyftang@users.noreply.github.com> Date: Tue, 20 Feb 2024 17:20:29 -0600 Subject: [PATCH] enh: JSON schema for guided generation now optionally respects field order (#264) --- clients/python/lorax/types.py | 4 ++-- launcher/Cargo.toml | 2 +- router/Cargo.toml | 2 +- server/lorax_server/utils/logits_process.py | 1 + server/tests/models/test_causal_lm.py | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index 8ad776c70..8dd024c7c 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -1,6 +1,6 @@ from enum import Enum from pydantic import BaseModel, validator, Field, ConfigDict -from typing import Optional, List, Dict, Any +from typing import Optional, List, Dict, Any, OrderedDict, Union from lorax.errors import ValidationError @@ -64,7 +64,7 @@ class ResponseFormat(BaseModel): model_config = ConfigDict(use_enum_values=True) type: ResponseFormatType - schema_spec: Dict[str, Any] = Field(alias="schema") + schema_spec: Union[Dict[str, Any], OrderedDict] = Field(alias="schema") class Parameters(BaseModel): diff --git a/launcher/Cargo.toml b/launcher/Cargo.toml index 803cc31ce..9053b1313 100644 --- a/launcher/Cargo.toml +++ b/launcher/Cargo.toml @@ -11,7 +11,7 @@ clap = { version = "4.1.4", features = ["derive", "env"] } ctrlc = { version = "3.2.5", features = ["termination"] } nix = "0.26.2" serde = { version = "1.0.152", features = ["derive"] } -serde_json = "1.0.93" +serde_json = { version = "1.0.93", features = ["preserve_order"] } tracing = "0.1.37" tracing-subscriber = { version = "0.3.16", features = ["json", "env-filter"] } diff --git a/router/Cargo.toml b/router/Cargo.toml index b6c750280..de69f34b4 100644 --- a/router/Cargo.toml +++ b/router/Cargo.toml @@ -30,7 +30,7 @@ opentelemetry-otlp = "0.12.0" rand = "0.8.5" reqwest = { version = "0.11.14", features = [] } serde = "1.0.152" -serde_json = "1.0.93" +serde_json = { version = "1.0.93", features = ["preserve_order"] } thiserror = "1.0.38" tokenizers = "0.13.4" tokio = { version = "1.25.0", features = ["rt", "rt-multi-thread", "parking_lot", "signal", "sync"] } diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index c5f58a274..a803b31f0 100644 --- a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -488,6 +488,7 @@ def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase): self.tokenizer = self.adapt_tokenizer(tokenizer) regex_string = build_regex_from_object(schema) + regex_string = '[\\n ]*' + regex_string # Hack to allow preceding whitespace self.fsm = RegexFSM(regex_string, tokenizer) self.fsm_state = FSMState(0) diff --git a/server/tests/models/test_causal_lm.py b/server/tests/models/test_causal_lm.py index e81a780cd..005eaabdc 100644 --- a/server/tests/models/test_causal_lm.py +++ b/server/tests/models/test_causal_lm.py @@ -145,7 +145,7 @@ def test_causal_lm_batch_type(default_causal_lm): @pytest.mark.parametrize("causal_lm_batch, generated_token_id", [ ("default_causal_lm_batch", 13), - ("schema_constrained_causal_lm_batch", 90), + ("schema_constrained_causal_lm_batch", 198), ]) def test_causal_lm_generate_token(default_causal_lm, causal_lm_batch, generated_token_id, request): causal_lm_batch = request.getfixturevalue(causal_lm_batch)