Skip to content

Commit

Permalink
Support fixed-length arrays
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf authored and brandonwillard committed Nov 12, 2023
1 parent 650311d commit e73d7fd
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 7 deletions.
23 changes: 16 additions & 7 deletions outlines/text/json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,16 +142,16 @@ def to_regex(resolver: Resolver, instance: dict):
instance_type = instance["type"]
if instance_type == "string":
if "maxLength" in instance or "minLength" in instance:
max_length = instance.get("maxLength", "")
min_length = instance.get("minLength", "")
max_items = instance.get("maxLength", "")
min_items = instance.get("minLength", "")
try:
if int(max_length) < int(min_length):
if int(max_items) < int(min_items):
raise ValueError(
"maxLength must be greater than or equal to minLength"
)
except ValueError:
pass
return f'"{STRING_INNER}{{{min_length},{max_length}}}"'
return f'"{STRING_INNER}{{{min_items},{max_items}}}"'
elif "pattern" in instance:
pattern = instance["pattern"]
if pattern[0] == "^" and pattern[-1] == "$":
Expand All @@ -168,12 +168,19 @@ def to_regex(resolver: Resolver, instance: dict):
return type_to_regex["integer"]

elif instance_type == "array":
min_items = instance.get("minItems", "0")
max_items = instance.get("maxItems", "")
if min_items == max_items:
num_repeats = "{" + str(int(min_items) - 1) + "}"
else:
num_repeats = "*"

if "items" in instance:
items_regex = to_regex(resolver, instance["items"])
return rf"\[({items_regex})(,({items_regex}))*\]"
return rf"\[({items_regex})(,({items_regex})){num_repeats}\]"
else:
# Here we need to make the choice to exclude generating list of objects
# if the specification of the object is not give, even though a JSON
# if the specification of the object is not given, even though a JSON
# object that contains an object here would be valid under the specification.
types = [
{"type": "boolean"},
Expand All @@ -183,7 +190,9 @@ def to_regex(resolver: Resolver, instance: dict):
{"type": "string"},
]
regexes = [to_regex(resolver, t) for t in types]
return rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)}))*\]"
return (
rf"\[({'|'.join(regexes)})(,({'|'.join(regexes)})){num_repeats}\]"
)

elif instance_type == "boolean":
return type_to_regex["boolean"]
Expand Down
24 changes: 24 additions & 0 deletions tests/text/test_json_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,30 @@ def test_match_number(pattern, does_match):
rf"\[({NUMBER})(,({NUMBER}))*\]",
[("[1e+9,1.3]", True)],
),
# array with a set length of 1
(
{
"title": "Foo",
"type": "array",
"items": {"type": "integer"},
"minItems": 1,
"maxItems": 1,
},
rf"\[({INTEGER})(,({INTEGER})){{0}}\]",
[("[1]", True), ("[1,2]", False), ('["a"]', False), ("[]", False)],
),
# array with a set length greather than 1
(
{
"title": "Foo",
"type": "array",
"items": {"type": "integer"},
"minItems": 3,
"maxItems": 3,
},
rf"\[({INTEGER})(,({INTEGER})){{2}}\]",
[("[1]", False), ("[]", False), ("[1,2,3]", True), ("[1,2,3,4]", False)],
),
# oneOf
(
{
Expand Down

0 comments on commit e73d7fd

Please sign in to comment.