Skip to content

Commit

Permalink
Modify resolution strategy to support resolving values internal to th…
Browse files Browse the repository at this point in the history
…e schema
  • Loading branch information
mattkindy authored and brandonwillard committed Oct 3, 2023
1 parent 5af26cb commit 4ce76ff
Show file tree
Hide file tree
Showing 2 changed files with 166 additions and 15 deletions.
73 changes: 58 additions & 15 deletions outlines/text/json_schema.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import itertools
import json
import re
from typing import Dict
from typing import Callable, Dict

STRING_INNER = r'(?:[^"\\\x00-\x1f\x7f-\x9f]|\\.)'
STRING = f'"{STRING_INNER}*"'
Expand Down Expand Up @@ -43,6 +43,45 @@ def build_regex_from_schema(schema: str):
return regex


def _ref_resolver(schema: Dict) -> Callable[[str], Dict]:
cache: Dict[str, Dict] = dict()

if "$id" in schema:
cache[schema["$id"]] = schema

if "$defs" in schema:
for definition, annotation in schema["$defs"].items():
cache[f"#/$defs/{definition}"] = annotation

if "$id" in annotation:
cache[annotation["$id"]] = annotation

def resolver(reference: str) -> Dict:
"""Resolve a $ref reference in the context of the top-level schema."""

if reference in cache:
return cache[reference]

path = reference.split("/")

# Navigate through the top-level schema based on the path
subschema = schema

if path[0] != "#":
raise ValueError(f"Unable to resolve reference: {reference}")

for step in path[1:]:
if step in subschema:
subschema = subschema[step]
else:
raise ValueError(f"Unable to resolve reference: {reference}")

cache[reference] = subschema
return subschema

return resolver


def build_schedule_from_schema(schema: str):
"""Turn a JSON schema into a regex that matches any JSON object that follows
this schema.
Expand Down Expand Up @@ -73,13 +112,7 @@ def build_schedule_from_schema(schema: str):
"""
schema = json.loads(schema)

# Find object definitions in the schema, if any
definitions = {}
if "$defs" in schema:
for definition, annotation in schema["$defs"].items():
definitions[f"#/$defs/{definition}"] = annotation

schema = expand_json_schema(schema, definitions)
schema = expand_json_schema(schema, resolver=_ref_resolver(schema))
schedule = build_schedule_from_instance(schema)

# Concatenate adjacent strings
Expand All @@ -92,20 +125,26 @@ def build_schedule_from_schema(schema: str):
return reduced_schedule


def expand_json_schema(raw_schema: Dict, definitions: Dict):
def expand_json_schema(
raw_schema: Dict,
resolver: Callable[[str], Dict],
):
"""Replace references by their value in the JSON Schema.
This recursively follows the references to other schemas in case
of nested models. Other schemas are stored under the "definitions"
key in the schema of the top-level model.
of nested models. Other schemas that may exist at a higher level
within the overall schema may be referenced via the `$ref` keyword
according to the JSON Schema specification.
Parameters
---------
raw_schema
The raw JSON schema as a Python dictionary, possibly with definitions
and references.
definitions
The currently known definitions.
resolver
A function that takes a reference and returns the corresponding schema
or subschema from the currently scoped top-level schema.
Returns
-------
Expand All @@ -116,16 +155,20 @@ def expand_json_schema(raw_schema: Dict, definitions: Dict):
expanded_properties = {}

if "properties" in raw_schema:
if "$id" in raw_schema:
# see https://json-schema.org/understanding-json-schema/structuring#bundling
resolver = _ref_resolver(raw_schema)

for name, value in raw_schema["properties"].items():
if "$ref" in value: # if item is a single element
expanded_properties[name] = expand_json_schema(
definitions[value["$ref"]], definitions
resolver(value["$ref"]), resolver
)
elif "type" in value and value["type"] == "array": # if item is a list
expanded_properties[name] = value
if "$ref" in value["items"]:
expanded_properties[name]["items"] = expand_json_schema(
definitions[value["items"]["$ref"]], definitions
resolver(value["items"]["$ref"]), resolver
)
else:
expanded_properties[name]["items"] = value["items"]
Expand Down
108 changes: 108 additions & 0 deletions tests/text/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,114 @@ def test_json_schema():
]


def test_json_schema_with_property_ref():
schema = """{
"title": "User",
"type": "object",
"properties": {
"user_id": {"title": "User Id", "type": "integer"},
"name": {"title": "Name", "type": "string"},
"a": {"$ref": "#/properties/name"},
"b": {"$ref": "#/properties/name"},
"c": {"$ref": "#/properties/name"}
},
"required": ["user_id", "name"]}
"""
schedule = build_schedule_from_schema(schema)
assert schedule == [
'\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*',
{"title": "User Id", "type": "integer"},
'[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*',
{"title": "Name", "type": "string"},
'[\\n ]*,[\\n ]*"a"[\\n ]*:[\\n ]*',
{"title": "Name", "type": "string"},
'[\\n ]*,[\\n ]*"b"[\\n ]*:[\\n ]*',
{"title": "Name", "type": "string"},
'[\\n ]*,[\\n ]*"c"[\\n ]*:[\\n ]*',
{"title": "Name", "type": "string"},
"[\\n ]*\\}",
]


def test_json_schema_with_def_ref():
schema = """{
"title": "User",
"type": "object",
"$defs": {
"name": {"title": "Name2", "type": "string"}
},
"properties": {
"user_id": {"title": "User Id", "type": "integer"},
"name": {"title": "Name", "type": "string"},
"name2": {"$ref": "#/$defs/name"}
},
"required": ["user_id", "name"]}
"""
schedule = build_schedule_from_schema(schema)
assert schedule == [
'\\{[\\n ]*"user_id"[\\n ]*:[\\n ]*',
{"title": "User Id", "type": "integer"},
'[\\n ]*,[\\n ]*"name"[\\n ]*:[\\n ]*',
{"title": "Name", "type": "string"},
'[\\n ]*,[\\n ]*"name2"[\\n ]*:[\\n ]*',
{"title": "Name2", "type": "string"},
"[\\n ]*\\}",
]


def test_json_schema_with_bundled_ref():
schema = """{
"$id": "https://example.com/schemas/customer",
"$schema": "https://json-schema.org/draft/2020-12/schema",
"title": "Customer",
"type": "object",
"properties": {
"first_name": { "type": "string" },
"last_name": { "type": "string" },
"shipping_address": { "$ref": "/schemas/address" },
"billing_address": { "$ref": "/schemas/address" }
},
"required": ["first_name", "last_name", "shipping_address", "billing_address"],
"$defs": {
"address": {
"title": "Address",
"$id": "/schemas/address",
"$schema": "http://json-schema.org/draft-07/schema#",
"type": "object",
"properties": {
"street_address": { "type": "string" },
"city": { "type": "string" },
"state": { "$ref": "#/definitions/state" }
},
"required": ["street_address", "city", "state"],
"definitions": {
"state": { "type": "object", "title": "State", "properties": { "name": { "type": "string" } }, "required": ["name"] }
}
}
}
}"""
schedule = build_schedule_from_schema(schema)
assert schedule == [
'\\{[\\n ]*"first_name"[\\n ]*:[\\n ]*',
{"type": "string"},
'[\\n ]*,[\\n ]*"last_name"[\\n ]*:[\\n ]*',
{"type": "string"},
'[\\n ]*,[\\n ]*"shipping_address"[\\n ]*:[\\n ]*\\{[\\n ]*"street_address"[\\n ]*:[\\n ]*',
{"type": "string"},
'[\\n ]*,[\\n ]*"city"[\\n ]*:[\\n ]*',
{"type": "string"},
'[\\n ]*,[\\n ]*"state"[\\n ]*:[\\n ]*\\{[\\n ]*"name"[\\n ]*:[\\n ]*',
{"type": "string"},
'[\\n ]*\\}[\\n ]*\\}[\\n ]*,[\\n ]*"billing_address"[\\n ]*:[\\n ]*\\{[\\n ]*"street_address"[\\n ]*:[\\n ]*',
{"type": "string"},
'[\\n ]*,[\\n ]*"city"[\\n ]*:[\\n ]*',
{"type": "string"},
'[\\n ]*,[\\n ]*"state"[\\n ]*:[\\n ]*\\{[\\n ]*"name"[\\n ]*:[\\n ]*',
{"type": "string"},
"[\\n ]*\\}[\\n ]*\\}[\\n ]*\\}",
]


class MockTokenizer:
pad_token_id = 0
eos_token_id = 0
Expand Down

0 comments on commit 4ce76ff

Please sign in to comment.