diff --git a/docs/reference/generation/generation.md b/docs/reference/generation/generation.md
index a14818514..930ad9d22 100644
--- a/docs/reference/generation/generation.md
+++ b/docs/reference/generation/generation.md
@@ -4,7 +4,7 @@ title: Generation
# Generation
-Once an [Outlines model](../models) is constructed you can use `outlines.generate` to generate text. Standard LLM generation is possible via `outlines.generate.text`, along with a variety of structured generation methods described below. (For a detailed technical explanation of how structured generation works, you may review the [Structured Generation Explanation](./structured_generation_explanation.md) page)
+Once an [Outlines model](../models/models.md) is constructed you can use `outlines.generate` to generate text. Standard LLM generation is possible via `outlines.generate.text`, along with a variety of structured generation methods described below. (For a detailed technical explanation of how structured generation works, you may review the [Structured Generation Explanation](./structured_generation_explanation.md) page)
Before generating text, you must construct an `outlines.model`. Example:
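A minimal sketch of what that looks like (the model name is illustrative; any supported backend works):

```python
import outlines

# Construct a model; the transformers backend is shown here
model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")

# Standard, unstructured generation
generator = outlines.generate.text(model)
answer = generator("What is the capital of France?")
```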
diff --git a/docs/reference/models/llamacpp.md b/docs/reference/models/llamacpp.md
index 24b0fdc97..51b62eca8 100644
--- a/docs/reference/models/llamacpp.md
+++ b/docs/reference/models/llamacpp.md
@@ -4,7 +4,11 @@ Outlines provides an integration with [Llama.cpp](https://github.com/ggerganov/l
!!! Note "Installation"
- You need to install the `llama-cpp-python` library to use the llama.cpp integration. See the [installation section](#installation) for instructions to install `llama-cpp-python` with CUDA, Metal, ROCm and other backends.
+    You need to install the `llama-cpp-python` library to use the llama.cpp integration. See the [installation section](#installation) for instructions to install `llama-cpp-python` with CUDA, Metal, ROCm and other backends. To get started quickly, you can also run:
+
+ ```bash
+ pip install "outlines[llamacpp]"
+ ```
## Load the model
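Once installed, loading a model is a one-liner. A hedged sketch (the repository id and weights file name are illustrative):

```python
import outlines

# Downloads the GGUF weights from the Hugging Face Hub and loads them
model = outlines.models.llamacpp(
    "TheBloke/phi-2-GGUF",   # repository id (illustrative)
    "phi-2.Q4_K_M.gguf",     # quantized weights file (illustrative)
)
```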
diff --git a/docs/reference/models/mlxlm.md b/docs/reference/models/mlxlm.md
index cf7bb7443..d435b9c1f 100644
--- a/docs/reference/models/mlxlm.md
+++ b/docs/reference/models/mlxlm.md
@@ -4,7 +4,11 @@ Outlines provides an integration with [mlx-lm](https://github.com/ml-explore/mlx
!!! Note "Installation"
- You need to install the `mlx` and `mlx-lm` libraries on a device which [supports Metal](https://support.apple.com/en-us/102894) to use the mlx-lm integration.
+    You need to install the `mlx` and `mlx-lm` libraries on a device which [supports Metal](https://support.apple.com/en-us/102894) to use the mlx-lm integration. To get started quickly, you can also run:
+
+ ```bash
+ pip install "outlines[mlxlm]"
+ ```
## Load the model
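A hedged loading sketch (the checkpoint name is illustrative; it must be an MLX-converted model):

```python
import outlines

model = outlines.models.mlxlm("mlx-community/Meta-Llama-3.1-8B-Instruct-4bit")
```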
diff --git a/docs/reference/models/openai.md b/docs/reference/models/openai.md
index 5c737c916..638107568 100644
--- a/docs/reference/models/openai.md
+++ b/docs/reference/models/openai.md
@@ -2,7 +2,11 @@
!!! Installation
- You need to install the `openai` library to be able to use the OpenAI API in Outlines.
+    You need to install the `openai` library to be able to use the OpenAI API in Outlines. Alternatively, you can run:
+
+ ```bash
+ pip install "outlines[openai]"
+ ```
## OpenAI models
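A hedged sketch of constructing an OpenAI model (the model name is illustrative; the client reads the `OPENAI_API_KEY` environment variable by default):

```python
import outlines

model = outlines.models.openai("gpt-4o-mini")
```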
diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md
index 2a13e28ec..f4c319540 100644
--- a/docs/reference/models/transformers.md
+++ b/docs/reference/models/transformers.md
@@ -3,10 +3,10 @@
!!! Installation
- You need to install the `transformer`, `datasets` and `torch` libraries to be able to use these models in Outlines:
+    You need to install the `transformers`, `datasets` and `torch` libraries to be able to use these models in Outlines, or alternatively:
```bash
- pip install torch transformers datasets
+ pip install "outlines[transformers]"
```
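A hedged loading sketch (the model name is illustrative; any causal LM on the Hugging Face Hub should work):

```python
import outlines

model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
```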
diff --git a/docs/reference/models/vllm.md b/docs/reference/models/vllm.md
index fb1c830fa..8789b588e 100644
--- a/docs/reference/models/vllm.md
+++ b/docs/reference/models/vllm.md
@@ -3,7 +3,11 @@
!!! Note "Installation"
- You need to install the `vllm` library to use the vLLM integration. See the [installation section](#installation) for instructions to install vLLM for CPU or ROCm.
+    You need to install the `vllm` library to use the vLLM integration. See the [installation section](#installation) for instructions to install vLLM for CPU or ROCm. To get started, you can also run:
+
+ ```bash
+ pip install "outlines[vllm]"
+ ```
## Load the model
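A hedged loading sketch (the model name is illustrative; vLLM places the model on GPU by default):

```python
import outlines

model = outlines.models.vllm("microsoft/Phi-3-mini-4k-instruct")
```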
diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css
index 33783c153..e71a6124f 100644
--- a/docs/stylesheets/extra.css
+++ b/docs/stylesheets/extra.css
@@ -112,6 +112,9 @@ h1.title {
h2.subtitle {
margin: 5px 0px 25px;
+ font-size: 1rem;
+ max-width: 540px;
+ margin: 0 auto;
}
.md-typeset {
diff --git a/examples/beam-cloud/README.md b/examples/beam-cloud/README.md
new file mode 100644
index 000000000..5f190d76f
--- /dev/null
+++ b/examples/beam-cloud/README.md
@@ -0,0 +1,5 @@
+## Deploy Outlines on Beam
+
+1. Create an account [here](https://beam.cloud) and install the Beam SDK
+2. Download the `app.py` file to your computer
+3. Deploy it as a serverless API by running: `beam deploy app.py:predict`
diff --git a/examples/beam-cloud/app.py b/examples/beam-cloud/app.py
new file mode 100644
index 000000000..fb6c2cb2b
--- /dev/null
+++ b/examples/beam-cloud/app.py
@@ -0,0 +1,39 @@
+from beam import Image, endpoint, env
+
+if env.is_remote():
+ import outlines
+
+
+# Pre-load models when the container first starts
+def load_models():
+ import outlines
+
+ model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")
+ return model
+
+
+@endpoint(
+ name="outlines-serverless",
+ gpu="A10G",
+ cpu=1,
+ memory="16Gi",
+ on_start=load_models,
+ image=Image().add_python_packages(
+ ["outlines", "torch", "transformers", "accelerate"]
+ ),
+)
+def predict(context, **inputs):
+ default_prompt = """You are a sentiment-labelling assistant.
+ Is the following review positive or negative?
+
+ Review: This restaurant is just awesome!
+ """
+
+ prompt = inputs.get("prompt", default_prompt)
+
+ # Unpack cached model from context
+ model = context.on_start_value
+ # Inference
+ generator = outlines.generate.choice(model, ["Positive", "Negative"])
+ answer = generator(prompt)
+ return {"answer": answer}
diff --git a/mkdocs.yml b/mkdocs.yml
index b97a2e41e..4888309b9 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -74,6 +74,7 @@ markdown_extensions:
- pymdownx.emoji:
emoji_index: !!python/name:material.extensions.emoji.twemoji
emoji_generator: !!python/name:material.extensions.emoji.to_svg
+ - pymdownx.snippets:
extra_css:
@@ -120,7 +121,10 @@ nav:
- Chain of Thought (CoT): cookbook/chain_of_thought.md
- ReAct Agent: cookbook/react_agent.md
- Vision-Language Models: cookbook/atomic_caption.md
+ - Structured Generation from PDFs: cookbook/read-pdfs.md
- Earnings reports to CSV: cookbook/earnings-reports.md
+ - Digitizing receipts with vision models: cookbook/receipt-digitization.md
+      - Extract event details from text: cookbook/extract_event_details.md
- Run on the cloud:
- BentoML: cookbook/deploy-using-bentoml.md
- Cerebrium: cookbook/deploy-using-cerebrium.md
diff --git a/outlines/base.py b/outlines/base.py
index 4de8ccf5a..29d42c54c 100644
--- a/outlines/base.py
+++ b/outlines/base.py
@@ -5,12 +5,23 @@
from typing import Callable, Optional
import numpy as np
-from numpy.lib.function_base import (
- _calculate_shapes,
- _parse_gufunc_signature,
- _parse_input_dimensions,
- _update_dim_sizes,
-)
+
+# Import required functions based on NumPy version
+np_major_version = int(np.__version__.split(".")[0])
+if np_major_version >= 2:
+ from numpy.lib._function_base_impl import (
+ _calculate_shapes,
+ _parse_gufunc_signature,
+ _parse_input_dimensions,
+ _update_dim_sizes,
+ )
+else:
+ from numpy.lib.function_base import (
+ _calculate_shapes,
+ _parse_gufunc_signature,
+ _parse_input_dimensions,
+ _update_dim_sizes,
+ )
# Allow nested loops for running in notebook. We don't enable it globally as it
# may interfere with other libraries that use asyncio.
diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py
index bbd3f44ba..9df2e9d12 100644
--- a/outlines/fsm/guide.py
+++ b/outlines/fsm/guide.py
@@ -15,7 +15,6 @@
)
from outlines import grammars
-from outlines.caching import cache
from outlines.fsm.parsing import PartialLark, PartialParserState
if TYPE_CHECKING:
@@ -73,7 +72,6 @@ def copy(self):
return self
-@cache()
def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs):
return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs)
diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py
index 98d2de59c..578ee7626 100644
--- a/outlines/fsm/json_schema.py
+++ b/outlines/fsm/json_schema.py
@@ -1,89 +1,10 @@
import inspect
import json
-import re
import warnings
-from typing import Callable, Optional, Tuple, Type, Union
+from enum import Enum
+from typing import Callable, Type, Union
-from jsonschema.protocols import Validator
from pydantic import BaseModel, create_model
-from referencing import Registry, Resource
-from referencing._core import Resolver
-from referencing.jsonschema import DRAFT202012
-
-# allow `\"`, `\\`, or any character which isn't a control sequence
-STRING_INNER = r'([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])'
-STRING = f'"{STRING_INNER}*"'
-
-INTEGER = r"(-)?(0|[1-9][0-9]*)"
-NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?"
-BOOLEAN = r"(true|false)"
-NULL = r"null"
-WHITESPACE = r"[ ]?"
-
-type_to_regex = {
- "string": STRING,
- "integer": INTEGER,
- "number": NUMBER,
- "boolean": BOOLEAN,
- "null": NULL,
-}
-
-DATE_TIME = r'"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"'
-DATE = r'"(?:\d{4})-(?:0[1-9]|1[0-2])-(?:0[1-9]|[1-2][0-9]|3[0-1])"'
-TIME = r'"(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\\.[0-9]+)?(Z)?"'
-UUID = r'"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"'
-
-format_to_regex = {
- "uuid": UUID,
- "date-time": DATE_TIME,
- "date": DATE,
- "time": TIME,
-}
-
-
-def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None):
- """Turn a JSON schema into a regex that matches any JSON object that follows
- this schema.
-
- JSON Schema is a declarative language that allows to annotate JSON documents
- with types and descriptions. These schemas can be generated from any Python
- datastructure that has type annotation: namedtuples, dataclasses, Pydantic
- models. And by ensuring that the generation respects the schema we ensure
- that the output can be parsed into these objects.
- This function parses the provided schema and builds a generation schedule which
- mixes deterministic generation (fixed strings), and sampling with constraints.
-
- Parameters
- ----------
- schema
- A string that represents a JSON Schema.
- whitespace_pattern
- Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
- Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
-
- Returns
- -------
- A generation schedule. A list of strings that represent the JSON
- schema's structure and regular expression that define the structure of
- the fields.
-
- References
- ----------
- .. [0] JSON Schema. https://json-schema.org/
-
- """
-
- schema = json.loads(schema)
- Validator.check_schema(schema)
-
- # Build reference resolver
- schema = Resource(contents=schema, specification=DRAFT202012)
- uri = schema.id() if schema.id() is not None else ""
- registry = Registry().with_resource(uri=uri, resource=schema)
- resolver = registry.resolver()
-
- content = schema.contents
- return to_regex(resolver, content, whitespace_pattern)
def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str:
@@ -119,412 +40,7 @@ def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -
return schema_str
-def _get_num_items_pattern(min_items, max_items, whitespace_pattern):
- # Helper function for arrays and objects
- min_items = int(min_items or 0)
- if max_items is None:
- return rf"{{{max(min_items - 1, 0)},}}"
- else:
- max_items = int(max_items)
- if max_items < 1:
- return None
- return rf"{{{max(min_items - 1, 0)},{max_items - 1}}}"
-
-
-def validate_quantifiers(
- min_bound: Optional[str], max_bound: Optional[str], start_offset: int = 0
-) -> Tuple[str, str]:
- """
- Ensures that the bounds of a number are valid. Bounds are used as quantifiers in the regex.
-
- Parameters
- ----------
- min_bound
- The minimum value that the number can take.
- max_bound
- The maximum value that the number can take.
- start_offset
- Number of elements that are already present in the regex but still need to be counted.
- ex: if the regex is already "(-)?(0|[1-9][0-9])", we will always have at least 1 digit, so the start_offset is 1.
-
- Returns
- -------
- min_bound
- The minimum value that the number can take.
- max_bound
- The maximum value that the number can take.
-
- Raises
- ------
- ValueError
- If the minimum bound is greater than the maximum bound.
-
- TypeError or ValueError
- If the minimum bound is not an integer or None.
- or
- If the maximum bound is not an integer or None.
- """
- min_bound = "" if min_bound is None else str(int(min_bound) - start_offset)
- max_bound = "" if max_bound is None else str(int(max_bound) - start_offset)
- if min_bound and max_bound:
- if int(max_bound) < int(min_bound):
- raise ValueError("max bound must be greater than or equal to min bound")
- return min_bound, max_bound
-
-
-def to_regex(
- resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None
-):
- """Translate a JSON Schema instance into a regex that validates the schema.
-
- Note
- ----
- Many features of JSON schema are missing:
- - Handle `additionalProperties` keyword
- - Handle types defined as a list
- - Handle constraints on numbers
- - Handle special patterns: `date`, `uri`, etc.
-
- This does not support recursive definitions.
-
- Parameters
- ----------
- resolver
- An object that resolves references to other instances within a schema
- instance
- The instance to translate
- whitespace_pattern
- Pattern to use for JSON syntactic whitespace (doesn't impact string literals)
- Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"`
- """
-
- # set whitespace pattern
- if whitespace_pattern is None:
- whitespace_pattern = WHITESPACE
-
- if instance == {}:
- # JSON Schema Spec: Empty object means unconstrained, any json type is legal
- types = [
- {"type": "boolean"},
- {"type": "null"},
- {"type": "number"},
- {"type": "integer"},
- {"type": "string"},
- {"type": "array"},
- {"type": "object"},
- ]
- regexes = [to_regex(resolver, t, whitespace_pattern) for t in types]
- regexes = [rf"({r})" for r in regexes]
- return rf"{'|'.join(regexes)}"
-
- elif "properties" in instance:
- regex = ""
- regex += r"\{"
- properties = instance["properties"]
- required_properties = instance.get("required", [])
- is_required = [item in required_properties for item in properties]
- # If at least one property is required, we include the one in the lastest position
- # without any comma.
- # For each property before it (optional or required), we add with a comma after the property.
- # For each property after it (optional), we add with a comma before the property.
- if any(is_required):
- last_required_pos = max([i for i, value in enumerate(is_required) if value])
- for i, (name, value) in enumerate(properties.items()):
- subregex = f'{whitespace_pattern}"{re.escape(name)}"{whitespace_pattern}:{whitespace_pattern}'
- subregex += to_regex(resolver, value, whitespace_pattern)
- if i < last_required_pos:
- subregex = f"{subregex}{whitespace_pattern},"
- elif i > last_required_pos:
- subregex = f"{whitespace_pattern},{subregex}"
- regex += subregex if is_required[i] else f"({subregex})?"
- # If no property is required, we have to create a possible pattern for each property in which
- # it's the last one necessarilly present. Then, we add the others as optional before and after
- # following the same strategy as described above.
- # The whole block is made optional to allow the case in which no property is returned.
- else:
- property_subregexes = []
- for i, (name, value) in enumerate(properties.items()):
- subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}'
- subregex += to_regex(resolver, value, whitespace_pattern)
- property_subregexes.append(subregex)
- possible_patterns = []
- for i in range(len(property_subregexes)):
- pattern = ""
- for subregex in property_subregexes[:i]:
- pattern += f"({subregex}{whitespace_pattern},)?"
- pattern += property_subregexes[i]
- for subregex in property_subregexes[i + 1 :]:
- pattern += f"({whitespace_pattern},{subregex})?"
- possible_patterns.append(pattern)
- regex += f"({'|'.join(possible_patterns)})?"
-
- regex += f"{whitespace_pattern}" + r"\}"
-
- return regex
-
- # To validate against allOf, the given data must be valid against all of the
- # given subschemas.
- elif "allOf" in instance:
- subregexes = [
- to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"]
- ]
- subregexes_str = [f"{subregex}" for subregex in subregexes]
- return rf"({''.join(subregexes_str)})"
-
- # To validate against `anyOf`, the given data must be valid against
- # any (one or more) of the given subschemas.
- elif "anyOf" in instance:
- subregexes = [
- to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"]
- ]
- return rf"({'|'.join(subregexes)})"
-
- # To validate against oneOf, the given data must be valid against exactly
- # one of the given subschemas.
- elif "oneOf" in instance:
- subregexes = [
- to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"]
- ]
-
- xor_patterns = [f"(?:{subregex})" for subregex in subregexes]
-
- return rf"({'|'.join(xor_patterns)})"
-
- # Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx
- elif "prefixItems" in instance:
- element_patterns = [
- to_regex(resolver, t, whitespace_pattern) for t in instance["prefixItems"]
- ]
- comma_split_pattern = rf"{whitespace_pattern},{whitespace_pattern}"
- tuple_inner = comma_split_pattern.join(element_patterns)
- return rf"\[{whitespace_pattern}{tuple_inner}{whitespace_pattern}\]"
-
- # The enum keyword is used to restrict a value to a fixed set of values. It
- # must be an array with at least one element, where each element is unique.
- elif "enum" in instance:
- choices = []
- for choice in instance["enum"]:
- if type(choice) in [int, float, bool, type(None), str]:
- choices.append(re.escape(json.dumps(choice)))
- else:
- raise TypeError(f"Unsupported data type in enum: {type(choice)}")
- return f"({'|'.join(choices)})"
-
- elif "const" in instance:
- const = instance["const"]
- if type(const) in [int, float, bool, type(None), str]:
- const = re.escape(json.dumps(const))
- else:
- raise TypeError(f"Unsupported data type in const: {type(const)}")
- return const
-
- elif "$ref" in instance:
- path = f"{instance['$ref']}"
- instance = resolver.lookup(path).contents
- return to_regex(resolver, instance, whitespace_pattern)
-
- # The type keyword may either be a string or an array:
- # - If it's a string, it is the name of one of the basic types.
- # - If it is an array, it must be an array of strings, where each string is
- # the name of one of the basic types, and each element is unique. In this
- # case, the JSON snippet is valid if it matches any of the given types.
- elif "type" in instance:
- instance_type = instance["type"]
- if instance_type == "string":
- if "maxLength" in instance or "minLength" in instance:
- max_items = instance.get("maxLength", "")
- min_items = instance.get("minLength", "")
- try:
- if int(max_items) < int(min_items):
- raise ValueError(
- "maxLength must be greater than or equal to minLength"
- ) # FIXME this raises an error but is caught right away by the except (meant for int("") I assume)
- except ValueError:
- pass
- return f'"{STRING_INNER}{{{min_items},{max_items}}}"'
- elif "pattern" in instance:
- pattern = instance["pattern"]
- if pattern[0] == "^" and pattern[-1] == "$":
- return rf'("{pattern[1:-1]}")'
- else:
- return rf'("{pattern}")'
- elif "format" in instance:
- format = instance["format"]
- if format == "date-time":
- return format_to_regex["date-time"]
- elif format == "uuid":
- return format_to_regex["uuid"]
- elif format == "date":
- return format_to_regex["date"]
- elif format == "time":
- return format_to_regex["time"]
- else:
- raise NotImplementedError(
- f"Format {format} is not supported by Outlines"
- )
- else:
- return type_to_regex["string"]
-
- elif instance_type == "number":
- bounds = {
- "minDigitsInteger",
- "maxDigitsInteger",
- "minDigitsFraction",
- "maxDigitsFraction",
- "minDigitsExponent",
- "maxDigitsExponent",
- }
- if bounds.intersection(set(instance.keys())):
- min_digits_integer, max_digits_integer = validate_quantifiers(
- instance.get("minDigitsInteger"),
- instance.get("maxDigitsInteger"),
- start_offset=1,
- )
- min_digits_fraction, max_digits_fraction = validate_quantifiers(
- instance.get("minDigitsFraction"), instance.get("maxDigitsFraction")
- )
- min_digits_exponent, max_digits_exponent = validate_quantifiers(
- instance.get("minDigitsExponent"), instance.get("maxDigitsExponent")
- )
- integers_quantifier = (
- f"{{{min_digits_integer},{max_digits_integer}}}"
- if min_digits_integer or max_digits_integer
- else "*"
- )
- fraction_quantifier = (
- f"{{{min_digits_fraction},{max_digits_fraction}}}"
- if min_digits_fraction or max_digits_fraction
- else "+"
- )
- exponent_quantifier = (
- f"{{{min_digits_exponent},{max_digits_exponent}}}"
- if min_digits_exponent or max_digits_exponent
- else "+"
- )
- return rf"((-)?(0|[1-9][0-9]{integers_quantifier}))(\.[0-9]{fraction_quantifier})?([eE][+-][0-9]{exponent_quantifier})?"
- return type_to_regex["number"]
-
- elif instance_type == "integer":
- if "minDigits" in instance or "maxDigits" in instance:
- min_digits, max_digits = validate_quantifiers(
- instance.get("minDigits"), instance.get("maxDigits"), start_offset=1
- )
- return rf"(-)?(0|[1-9][0-9]{{{min_digits},{max_digits}}})"
- return type_to_regex["integer"]
-
- elif instance_type == "array":
- num_repeats = _get_num_items_pattern(
- instance.get("minItems"), instance.get("maxItems"), whitespace_pattern
- )
- if num_repeats is None:
- return rf"\[{whitespace_pattern}\]"
-
- allow_empty = "?" if int(instance.get("minItems", 0)) == 0 else ""
-
- if "items" in instance:
- items_regex = to_regex(resolver, instance["items"], whitespace_pattern)
- return rf"\[{whitespace_pattern}(({items_regex})(,{whitespace_pattern}({items_regex})){num_repeats}){allow_empty}{whitespace_pattern}\]"
- else:
- # Here we need to make the choice to exclude generating list of objects
- # if the specification of the object is not given, even though a JSON
- # object that contains an object here would be valid under the specification.
- legal_types = [
- {"type": "boolean"},
- {"type": "null"},
- {"type": "number"},
- {"type": "integer"},
- {"type": "string"},
- ]
- depth = instance.get("depth", 2)
- if depth > 0:
- legal_types.append({"type": "object", "depth": depth - 1})
- legal_types.append({"type": "array", "depth": depth - 1})
-
- regexes = [
- to_regex(resolver, t, whitespace_pattern) for t in legal_types
- ]
- return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}{allow_empty}{whitespace_pattern}\]"
-
- elif instance_type == "object":
- # pattern for json object with values defined by instance["additionalProperties"]
- # enforces value type constraints recursively, "minProperties", and "maxProperties"
- # doesn't enforce "required", "dependencies", "propertyNames" "any/all/on Of"
- num_repeats = _get_num_items_pattern(
- instance.get("minProperties"),
- instance.get("maxProperties"),
- whitespace_pattern,
- )
- if num_repeats is None:
- return rf"\{{{whitespace_pattern}\}}"
-
- allow_empty = "?" if int(instance.get("minProperties", 0)) == 0 else ""
-
- additional_properties = instance.get("additionalProperties")
-
- if additional_properties is None or additional_properties is True:
- # JSON Schema behavior: If the additionalProperties of an object is
- # unset or True, it is unconstrained object.
- # We handle this by setting additionalProperties to anyOf: {all types}
-
- legal_types = [
- {"type": "string"},
- {"type": "number"},
- {"type": "boolean"},
- {"type": "null"},
- ]
-
- # We set the object depth to 2 to keep the expression finite, but the "depth"
- # key is not a true component of the JSON Schema specification.
- depth = instance.get("depth", 2)
- if depth > 0:
- legal_types.append({"type": "object", "depth": depth - 1})
- legal_types.append({"type": "array", "depth": depth - 1})
- additional_properties = {"anyOf": legal_types}
-
- value_pattern = to_regex(
- resolver, additional_properties, whitespace_pattern
- )
- key_value_pattern = (
- f"{STRING}{whitespace_pattern}:{whitespace_pattern}{value_pattern}"
- )
- key_value_successor_pattern = (
- f"{whitespace_pattern},{whitespace_pattern}{key_value_pattern}"
- )
- multiple_key_value_pattern = f"({key_value_pattern}({key_value_successor_pattern}){num_repeats}){allow_empty}"
-
- return (
- r"\{"
- + whitespace_pattern
- + multiple_key_value_pattern
- + whitespace_pattern
- + r"\}"
- )
-
- elif instance_type == "boolean":
- return type_to_regex["boolean"]
-
- elif instance_type == "null":
- return type_to_regex["null"]
-
- elif isinstance(instance_type, list):
- # Here we need to make the choice to exclude generating an object
- # if the specification of the object is not give, even though a JSON
- # object that contains an object here would be valid under the specification.
- regexes = [
- to_regex(resolver, {"type": t}, whitespace_pattern)
- for t in instance_type
- if t != "object"
- ]
- return rf"({'|'.join(regexes)})"
-
- raise NotImplementedError(
- f"""Could not translate the instance {instance} to a
- regular expression. Make sure it is valid to the JSON Schema specification. If
- it is, please open an issue on the Outlines repository"""
- )
-
-
-def get_schema_from_signature(fn: Callable) -> str:
+def get_schema_from_signature(fn: Callable) -> dict:
"""Turn a function signature into a JSON schema.
Every JSON object valid to the output JSON Schema can be passed
@@ -550,3 +66,18 @@ def get_schema_from_signature(fn: Callable) -> str:
model = create_model(fn_name, **arguments)
return model.model_json_schema()
+
+
+def get_schema_from_enum(myenum: type[Enum]) -> dict:
+ if len(myenum) == 0:
+ raise ValueError(
+ f"Your enum class {myenum.__name__} has 0 members. If you are working with an enum of functions, do not forget to register them as callable (using `partial` for instance)"
+ )
+ choices = [
+ get_schema_from_signature(elt.value.func)
+ if callable(elt.value)
+ else {"const": elt.value}
+ for elt in myenum
+ ]
+ schema = {"title": myenum.__name__, "oneOf": choices}
+ return schema
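The new `get_schema_from_enum` helper builds a `oneOf` schema over the enum members, turning callable members (which must be wrapped, e.g. in `functools.partial`) into the JSON schema of their signature, and plain values into `const` entries. A usage sketch mirroring the test added further below:

```python
from enum import Enum
from functools import partial

from outlines.fsm.json_schema import get_schema_from_enum


def add(a: float, b: float) -> float:
    return a + b


class Action(Enum):
    add = partial(add)  # callables must be wrapped, e.g. with partial
    greeting = "hello"


# -> {"title": "Action", "oneOf": [<schema of add's signature>, {"const": "hello"}]}
print(get_schema_from_enum(Action))
```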
diff --git a/outlines/generate/choice.py b/outlines/generate/choice.py
index 595513d52..afb998f52 100644
--- a/outlines/generate/choice.py
+++ b/outlines/generate/choice.py
@@ -1,7 +1,12 @@
import json as pyjson
+import re
+from enum import Enum
from functools import singledispatch
-from typing import Callable, List
+from typing import Callable, List, Union
+from outlines_core.fsm.json_schema import build_regex_from_schema
+
+from outlines.fsm.json_schema import get_schema_from_enum
from outlines.generate.api import SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.samplers import Sampler, multinomial
@@ -12,12 +17,19 @@
@singledispatch
def choice(
- model, choices: List[str], sampler: Sampler = multinomial()
+ model, choices: Union[List[str], type[Enum]], sampler: Sampler = multinomial()
) -> SequenceGeneratorAdapter:
- regex_str = r"(" + r"|".join(choices) + r")"
+ if isinstance(choices, type(Enum)):
+ regex_str = build_regex_from_schema(pyjson.dumps(get_schema_from_enum(choices)))
+ else:
+ choices = [re.escape(choice) for choice in choices] # type: ignore
+ regex_str = r"(" + r"|".join(choices) + r")"
generator = regex(model, regex_str, sampler)
- generator.format_sequence = lambda x: x
+ if isinstance(choices, type(Enum)):
+ generator.format_sequence = lambda x: pyjson.loads(x)
+ else:
+ generator.format_sequence = lambda x: x
return generator
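With this change, `generate.choice` accepts either a list of strings (whose regex metacharacters are now escaped) or an `Enum` class. A hedged usage sketch (model name illustrative):

```python
from enum import Enum

import outlines


class Sentiment(Enum):
    positive = "Positive"
    negative = "Negative"


model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct")

# Plain string choices still work; literal characters like "." no longer
# need manual escaping
generator = outlines.generate.choice(model, ["Positive", "Negative"])

# Enums are converted to a JSON schema and the output parsed back with json.loads
enum_generator = outlines.generate.choice(model, Sentiment)
```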
diff --git a/outlines/generate/json.py b/outlines/generate/json.py
index f75878d29..d098d920d 100644
--- a/outlines/generate/json.py
+++ b/outlines/generate/json.py
@@ -1,10 +1,12 @@
import json as pyjson
+from enum import Enum
from functools import singledispatch
from typing import Callable, Optional, Union
+from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
-from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature
+from outlines.fsm.json_schema import get_schema_from_enum, get_schema_from_signature
from outlines.generate.api import SequenceGeneratorAdapter
from outlines.models import OpenAI
from outlines.samplers import Sampler, multinomial
@@ -48,6 +50,11 @@ def json(
regex_str = build_regex_from_schema(schema, whitespace_pattern)
generator = regex(model, regex_str, sampler)
generator.format_sequence = lambda x: schema_object.parse_raw(x)
+ elif isinstance(schema_object, type(Enum)):
+ schema = pyjson.dumps(get_schema_from_enum(schema_object))
+ regex_str = build_regex_from_schema(schema, whitespace_pattern)
+ generator = regex(model, regex_str, sampler)
+ generator.format_sequence = lambda x: pyjson.loads(x)
elif callable(schema_object):
schema = pyjson.dumps(get_schema_from_signature(schema_object))
regex_str = build_regex_from_schema(schema, whitespace_pattern)
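`generate.json` gains the same `Enum` branch: the enum is converted to a schema with `get_schema_from_enum`, compiled to a regex, and the output parsed back with `json.loads`. A hedged sketch, reusing the `Sentiment` enum from the example above:

```python
generator = outlines.generate.json(model, Sentiment)
result = generator("Is this review positive or negative? Review: great!")
```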
diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py
index d28fcb2d7..fe6f861ac 100644
--- a/outlines/models/__init__.py
+++ b/outlines/models/__init__.py
@@ -16,4 +16,4 @@
from .transformers_vision import TransformersVision, transformers_vision
from .vllm import VLLM, vllm
-LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, MLXLM, VLLM]
+LogitsGenerator = Union[Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM]
diff --git a/outlines/models/mlxlm.py b/outlines/models/mlxlm.py
index 6e63ef5b6..d8b7e032c 100644
--- a/outlines/models/mlxlm.py
+++ b/outlines/models/mlxlm.py
@@ -167,12 +167,7 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]:
prob = softmax_logits[0, token]
return token, prob
- kv_heads = (
- [self.model.n_kv_heads] * len(self.model.layers)
- if isinstance(self.model.n_kv_heads, int)
- else self.model.n_kv_heads
- )
- cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads]
+ cache = mlx_lm.models.cache.make_prompt_cache(self.model)
# kv cache contains processed input IDs, we pass the unprocessed inputs and cache to model()
unprocessed_input_ids = prompt
diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py
index 7ecc9013f..444492500 100644
--- a/outlines/models/transformers.py
+++ b/outlines/models/transformers.py
@@ -2,8 +2,6 @@
import inspect
from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union
-from datasets.fingerprint import Hasher
-
from outlines.generate.api import GenerationParameters, SamplingParameters
from outlines.models.tokenizer import Tokenizer
@@ -116,6 +114,8 @@ def __eq__(self, other):
return NotImplemented
def __hash__(self):
+ from datasets.fingerprint import Hasher
+
return hash(Hasher.hash(self.tokenizer))
def __getstate__(self):
diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py
index d1f97bde2..778c27c6f 100644
--- a/outlines/models/vllm.py
+++ b/outlines/models/vllm.py
@@ -1,11 +1,10 @@
import dataclasses
from typing import TYPE_CHECKING, List, Optional, Union
-from transformers import SPIECE_UNDERLINE, PreTrainedTokenizerBase
-
from outlines.generate.api import GenerationParameters, SamplingParameters
if TYPE_CHECKING:
+ from transformers import PreTrainedTokenizerBase
from vllm import LLM
from vllm.sampling_params import SamplingParams
@@ -188,7 +187,7 @@ def vllm(model_name: str, **vllm_model_params):
return VLLM(model)
-def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase:
+def adapt_tokenizer(tokenizer: "PreTrainedTokenizerBase") -> "PreTrainedTokenizerBase":
"""Adapt a tokenizer to use to compile the FSM.
The API of Outlines tokenizers is slightly different to that of `transformers`. In
@@ -205,6 +204,8 @@ def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBa
PreTrainedTokenizerBase
The adapted tokenizer.
"""
+ from transformers import SPIECE_UNDERLINE
+
tokenizer.vocabulary = tokenizer.get_vocab()
tokenizer.special_tokens = set(tokenizer.all_special_tokens)
diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py
index 9a52abecd..44b55af2e 100644
--- a/outlines/processors/base_logits_processor.py
+++ b/outlines/processors/base_logits_processor.py
@@ -20,6 +20,16 @@ def is_mlx_array_type(array_type):
return issubclass(array_type, mx.array)
+def is_jax_array_type(array_type):
+ try:
+ import jaxlib
+ except ImportError:
+ return False
+ return issubclass(array_type, jaxlib.xla_extension.ArrayImpl) or isinstance(
+ array_type, jaxlib.xla_extension.ArrayImpl
+ )
+
+
class OutlinesLogitsProcessor(Protocol):
"""
Base class for logits processors which normalizes types of logits:
@@ -97,9 +107,18 @@ def _to_torch(tensor_like: Array) -> torch.Tensor:
return torch.tensor(tensor_like)
elif is_mlx_array_type(type(tensor_like)):
- # mlx -> torch -> mlx conversion docs:
- # https://ml-explore.github.io/mlx/build/html/usage/numpy.html
- return torch.from_dlpack(tensor_like)
+ import mlx.core as mx
+
+ # https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch
+ return torch.from_dlpack(
+ np.array(tensor_like.astype(mx.float32), copy=False)
+ )
+
+ elif is_jax_array_type(type(tensor_like)):
+ import jax
+
+ torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like))
+ return torch_tensor
else:
raise TypeError(
@@ -129,6 +148,11 @@ def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array:
# numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch
return mx.array(tensor.float().numpy())
+ elif is_jax_array_type(target_type):
+ import jax
+
+ return jax.dlpack.from_dlpack(tensor)
+
else:
raise TypeError(
f"Failed to convert torch tensors to target_type `{target_type}`"
diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py
index 50ae6e3ee..6a0d4236e 100644
--- a/outlines/processors/structured.py
+++ b/outlines/processors/structured.py
@@ -27,10 +27,11 @@
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union
import torch
+from outlines_core.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from outlines.fsm.guide import CFGGuide, Guide, RegexGuide
-from outlines.fsm.json_schema import build_regex_from_schema, convert_json_schema_to_str
+from outlines.fsm.json_schema import convert_json_schema_to_str
from .base_logits_processor import OutlinesLogitsProcessor
diff --git a/pyproject.toml b/pyproject.toml
index fa7005afd..b83275f89 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -28,7 +28,7 @@ dependencies = [
"jinja2",
"lark",
"nest_asyncio",
- "numpy<2.0.0",
+ "numpy",
"cloudpickle",
"diskcache",
"pydantic>=2.0",
@@ -36,16 +36,21 @@ dependencies = [
"jsonschema",
"requests",
"tqdm",
- "datasets",
"typing_extensions",
"pycountry",
"airportsdata",
"torch",
- "outlines_core==0.1.14",
+ "outlines_core==0.1.25",
]
dynamic = ["version"]
[project.optional-dependencies]
+vllm = ["vllm", "transformers", "numpy<2"]
+transformers = ["transformers", "accelerate", "datasets", "numpy<2"]
+mlxlm = ["mlx-lm", "datasets"]
+openai = ["openai"]
+llamacpp = ["llama-cpp-python", "transformers", "datasets", "numpy<2"]
+exllamav2 = ["exllamav2"]
test = [
"pre-commit",
"pytest",
@@ -58,13 +63,15 @@ test = [
"beartype<0.16.0",
"responses",
"llama-cpp-python",
- "mlx-lm; platform_machine == 'arm64' and sys_platform == 'darwin'",
+ "mlx-lm>=0.19.2; platform_machine == 'arm64' and sys_platform == 'darwin'",
"huggingface_hub",
"openai>=1.0.0",
+ "datasets",
"vllm; sys_platform != 'darwin'",
"transformers",
"pillow",
"exllamav2",
+ "jax"
]
serve = [
"vllm>=0.3.0",
@@ -109,6 +116,9 @@ enable_incomplete_feature = ["Unpack"]
[[tool.mypy.overrides]]
module = [
"exllamav2.*",
+ "jax",
+ "jaxlib",
+ "jax.numpy",
"jinja2",
"jsonschema.*",
"openai.*",
diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py
index 510faf4b0..bf25c43c4 100644
--- a/tests/fsm/test_guide.py
+++ b/tests/fsm/test_guide.py
@@ -59,7 +59,7 @@ def convert_token_to_string(self, token):
tokenizer = MockTokenizer()
fsm = RegexGuide.from_regex(regex_str, tokenizer)
- assert fsm.states_to_token_maps == {0: {1: 1}}
+ assert fsm.states_to_token_maps.get_transitions() == {0: {1: 1}}
instruction = fsm.get_next_instruction(0)
assert isinstance(instruction, Generate)
@@ -70,9 +70,6 @@ def convert_token_to_string(self, token):
assert fsm.is_final_state(0) is False
- for state in fsm.final_states:
- assert fsm.is_final_state(state) is True
-
def test_regex_multi_byte_llama_like():
class MockTokenizer:
@@ -100,7 +97,7 @@ def convert_token_to_string(self, token):
tokenizer = MockTokenizer()
fsm = RegexGuide.from_regex(regex_str, tokenizer)
- assert fsm.states_to_token_maps == {
+ assert fsm.states_to_token_maps.get_transitions() == {
0: {5: 1, 4: 2},
1: {6: 3},
3: {7: 4},
@@ -116,9 +113,6 @@ def convert_token_to_string(self, token):
assert fsm.is_final_state(0) is False
- for state in fsm.final_states:
- assert fsm.is_final_state(state) is True
-
def test_regex_multi_byte_gpt2_like():
class MockTokenizer:
@@ -147,7 +141,7 @@ def convert_token_to_string(self, token):
tokenizer = MockTokenizer()
fsm = RegexGuide.from_regex(regex_str, tokenizer)
- assert fsm.states_to_token_maps == {
+ assert fsm.states_to_token_maps.get_transitions() == {
0: {5: 1, 10: 2},
1: {8: 5, 4: 3},
2: {11: 3},
@@ -163,9 +157,6 @@ def convert_token_to_string(self, token):
assert fsm.is_final_state(0) is False
- for state in fsm.final_states:
- assert fsm.is_final_state(state) is True
-
def test_regex_final_state():
"""Make sure that the FSM stays in the final state as we keep generating"""
diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py
index 7565ff642..23864e029 100644
--- a/tests/fsm/test_json_schema.py
+++ b/tests/fsm/test_json_schema.py
@@ -1,27 +1,14 @@
import json
-import re
-from typing import List, Literal, Union
+from contextlib import nullcontext
+from enum import Enum
+from functools import partial
+from typing import List
-import interegular
import pytest
-from pydantic import BaseModel, Field, constr
+from outlines_core.fsm.json_schema import build_regex_from_schema
+from pydantic import BaseModel, constr
-from outlines.fsm.json_schema import (
- BOOLEAN,
- DATE,
- DATE_TIME,
- INTEGER,
- NULL,
- NUMBER,
- STRING,
- STRING_INNER,
- TIME,
- UUID,
- WHITESPACE,
- build_regex_from_schema,
- get_schema_from_signature,
- to_regex,
-)
+from outlines.fsm.json_schema import get_schema_from_enum, get_schema_from_signature
def test_function_basic():
@@ -54,988 +41,38 @@ class User(BaseModel):
is_true: bool
schema = json.dumps(User.model_json_schema())
- schedule = build_regex_from_schema(schema)
- assert isinstance(schedule, str)
-
-
-@pytest.mark.parametrize(
- "pattern,does_match",
- [
- ({"integer": "0"}, True),
- ({"integer": "1"}, True),
- ({"integer": "-1"}, True),
- ({"integer": "01"}, False),
- ({"integer": "1.3"}, False),
- ({"integer": "t"}, False),
- ],
-)
-def test_match_integer(pattern, does_match):
- step = {"title": "Foo", "type": "integer"}
- regex = to_regex(None, step)
- assert regex == INTEGER
-
- value = pattern["integer"]
- match = re.fullmatch(regex, value)
- if does_match:
- assert match[0] == value
- assert match.span() == (0, len(value))
- else:
- assert match is None
+ regex_str = build_regex_from_schema(schema)
+ assert isinstance(regex_str, str)
-@pytest.mark.parametrize(
- "pattern,does_match",
- [
- ({"number": "1"}, True),
- ({"number": "0"}, True),
- ({"number": "01"}, False),
- ({"number": ".3"}, False),
- ({"number": "1.3"}, True),
- ({"number": "-1.3"}, True),
- ({"number": "1.3e9"}, False),
- ({"number": "1.3e+9"}, True),
- ],
-)
-def test_match_number(pattern, does_match):
- step = {"title": "Foo", "type": "number"}
- regex = to_regex(None, step)
- assert regex == NUMBER
+def add(a: float, b: float) -> float:
+ return a + b
- value = pattern["number"]
- match = re.fullmatch(regex, value)
- if does_match:
- assert match[0] == value
- assert match.span() == (0, len(value))
- else:
- assert match is None
+class MyEnum(Enum):
+ add = partial(add)
+ a = "a"
+ b = 2
-@pytest.mark.parametrize(
- "schema,regex,examples",
- [
- # String
- (
- {"title": "Foo", "type": "string"},
- STRING,
- [
- ("unquotedstring", False),
- ('"(parenthesized_string)"', True),
- ('"malformed) parenthesis (((() string"', True),
- ('"quoted_string"', True),
- (r'"escape_\character"', False),
- (r'"double_\\escape"', True),
- (r'"\n"', False),
- (r'"\\n"', True),
- (r'"unescaped " quote"', False),
- (r'"escaped \" quote"', True),
- ],
- ),
- # String with maximum length
- (
- {"title": "Foo", "type": "string", "maxLength": 3},
- f'"{STRING_INNER}{{,3}}"',
- [('"ab"', True), ('"a""', False), ('"abcd"', False)],
- ),
- # String with minimum length
- (
- {"title": "Foo", "type": "string", "minLength": 3},
- f'"{STRING_INNER}{{3,}}"',
- [('"ab"', False), ('"abcd"', True), ('"abc""', False)],
- ),
- # String with both minimum and maximum length
- (
- {"title": "Foo", "type": "string", "minLength": 3, "maxLength": 5},
- f'"{STRING_INNER}{{3,5}}"',
- [('"ab"', False), ('"abcd"', True), ('"abcdef""', False)],
- ),
- # String defined by a regular expression
- (
- {"title": "Foo", "type": "string", "pattern": r"^[a-z]$"},
- r'("[a-z]")',
- [('"a"', True), ('"1"', False)],
- ),
- # Boolean
- (
- {"title": "Foo", "type": "boolean"},
- BOOLEAN,
- [
- ("true", True),
- ("false", True),
- ("null", False),
- ("0", False),
- ],
- ),
- # Null
- (
- {"title": "Foo", "type": "null"},
- NULL,
- [
- ("null", True),
- ("true", False),
- ("0", False),
- ],
- ),
- # Const string
- (
- {"title": "Foo", "const": "Marc", "type": "string"},
- '"Marc"',
- [('"Marc"', True), ('"Jean"', False), ('"John"', False)],
- ),
- # Make sure strings are escaped with regex escaping
- (
- {"title": "Foo", "const": ".*", "type": "string"},
- r'"\.\*"',
- [('".*"', True), (r'"\s*"', False), (r'"\.\*"', False)],
- ),
- # Make sure strings are escaped with JSON escaping
- (
- {"title": "Foo", "const": '"', "type": "string"},
- r'"\\""',
- [('"\\""', True), ('"""', False)],
- ),
- # Const integer
- (
- {"title": "Foo", "const": 0, "type": "integer"},
- "0",
- [("0", True), ("1", False), ("a", False)],
- ),
- # Const float
- (
- {"title": "Foo", "const": 0.2, "type": "float"},
- r"0\.2",
- [("0.2", True), ("032", False)],
- ),
- # Const boolean
- (
- {"title": "Foo", "const": True, "type": "boolean"},
- "true",
- [("true", True), ("True", False)],
- ),
- # Const null
- (
- {"title": "Foo", "const": None, "type": "null"},
- "null",
- [("null", True), ("None", False), ("", False)],
- ),
- # Enum string
- (
- {"title": "Foo", "enum": ["Marc", "Jean"], "type": "string"},
- '("Marc"|"Jean")',
- [('"Marc"', True), ('"Jean"', True), ('"John"', False)],
- ),
- # Make sure strings are escaped with regex and JSON escaping
- (
- {"title": "Foo", "enum": [".*", r"\s*"], "type": "string"},
- r'("\.\*"|"\\\\s\*")',
- [('".*"', True), (r'"\\s*"', True), (r'"\.\*"', False)],
- ),
- # Enum integer
- (
- {"title": "Foo", "enum": [0, 1], "type": "integer"},
- "(0|1)",
- [("0", True), ("1", True), ("a", False)],
- ),
- # Enum mix of types
- (
- {"title": "Foo", "enum": [6, 5.3, "potato", True, None]},
- r'(6|5\.3|"potato"|true|null)',
- [
- ("6", True),
- ("5.3", True),
- ('"potato"', True),
- ("true", True),
- ("null", True),
- ("523", False),
- ("True", False),
- ("None", False),
- ],
- ),
- # integer
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {"count": {"title": "Count", "type": "integer"}},
- "required": ["count"],
- },
- '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?\\}',
- [('{ "count": 100 }', True)],
- ),
- # integer with minimum digits
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {
- "count": {"title": "Count", "type": "integer", "minDigits": 3}
- },
- "required": ["count"],
- },
- '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,})[ ]?\\}',
- [('{ "count": 10 }', False), ('{ "count": 100 }', True)],
- ),
- # integer with maximum digits
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {
- "count": {"title": "Count", "type": "integer", "maxDigits": 3}
- },
- "required": ["count"],
- },
- '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{,2})[ ]?\\}',
- [('{ "count": 100 }', True), ('{ "count": 1000 }', False)],
- ),
- # integer with minimum and maximum digits
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {
- "count": {
- "title": "Count",
- "type": "integer",
- "minDigits": 3,
- "maxDigits": 5,
- }
- },
- "required": ["count"],
- },
- '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,4})[ ]?\\}',
- [
- ('{ "count": 10 }', False),
- ('{ "count": 100 }', True),
- ('{ "count": 10000 }', True),
- ('{ "count": 100000 }', False),
- ],
- ),
- # number
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {"count": {"title": "Count", "type": "number"}},
- "required": ["count"],
- },
- '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\\}',
- [('{ "count": 100 }', True), ('{ "count": 100.5 }', True)],
- ),
- # number with min and max integer digits
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {
- "count": {
- "title": "Count",
- "type": "number",
- "minDigitsInteger": 3,
- "maxDigitsInteger": 5,
- }
- },
- "required": ["count"],
- },
- '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]{2,4}))(\\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\\}',
- [
- ('{ "count": 10.005 }', False),
- ('{ "count": 100.005 }', True),
- ('{ "count": 10000.005 }', True),
- ('{ "count": 100000.005 }', False),
- ],
- ),
- # number with min and max fraction digits
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {
- "count": {
- "title": "Count",
- "type": "number",
- "minDigitsFraction": 3,
- "maxDigitsFraction": 5,
- }
- },
- "required": ["count"],
- },
- '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]{3,5})?([eE][+-][0-9]+)?[ ]?\\}',
- [
- ('{ "count": 1.05 }', False),
- ('{ "count": 1.005 }', True),
- ('{ "count": 1.00005 }', True),
- ('{ "count": 1.000005 }', False),
- ],
- ),
- # number with min and max exponent digits
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {
- "count": {
- "title": "Count",
- "type": "number",
- "minDigitsExponent": 3,
- "maxDigitsExponent": 5,
- }
- },
- "required": ["count"],
- },
- '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]+)?([eE][+-][0-9]{3,5})?[ ]?\\}',
- [
- ('{ "count": 1.05e1 }', False),
- ('{ "count": 1.05e+001 }', True),
- ('{ "count": 1.05e-00001 }', True),
- ('{ "count": 1.05e0000001 }', False),
- ],
- ),
- # number with min and max integer, fraction and exponent digits
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {
- "count": {
- "title": "Count",
- "type": "number",
- "minDigitsInteger": 3,
- "maxDigitsInteger": 5,
- "minDigitsFraction": 3,
- "maxDigitsFraction": 5,
- "minDigitsExponent": 3,
- "maxDigitsExponent": 5,
- }
- },
- "required": ["count"],
- },
- '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]{2,4}))(\\.[0-9]{3,5})?([eE][+-][0-9]{3,5})?[ ]?\\}',
- [
- ('{ "count": 1.05e1 }', False),
- ('{ "count": 100.005e+001 }', True),
- ('{ "count": 10000.00005e-00001 }', True),
- ('{ "count": 100000.000005e0000001 }', False),
- ],
- ),
- # array
- (
- {"title": "Foo", "type": "array", "items": {"type": "number"}},
- rf"\[{WHITESPACE}(({NUMBER})(,{WHITESPACE}({NUMBER})){{0,}})?{WHITESPACE}\]",
- [("[1e+9,1.3]", True), ("[]", True), ("[1", False)],
- ),
- # array with a set length of 1
- (
- {
- "title": "Foo",
- "type": "array",
- "items": {"type": "integer"},
- "minItems": 1,
- "maxItems": 1,
- },
- rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{0,0}}){WHITESPACE}\]",
- [("[1]", True), ("[1,2]", False), ('["a"]', False), ("[]", False)],
- ),
- # array with a set length greather than 1
- (
- {
- "title": "Foo",
- "type": "array",
- "items": {"type": "integer"},
- "minItems": 3,
- "maxItems": 3,
- },
- rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{2,2}}){WHITESPACE}\]",
- [("[1]", False), ("[]", False), ("[1,2,3]", True), ("[1,2,3,4]", False)],
- ),
- # array with length 0
- (
- {
- "title": "Foo",
- "type": "array",
- "items": {"type": "integer"},
- "minItems": 0,
- "maxItems": 0,
- },
- rf"\[{WHITESPACE}\]",
- [("[1]", False), ("[]", True), ("[1,2,3]", False), ("[1,2,3,4]", False)],
- ),
- # object
- (
- {
- "title": "TestSchema",
- "type": "object",
- "properties": {
- "test_dict": {
- "title": "Test Dict",
- "additionalProperties": {"type": "string"},
- "type": "object",
- }
- },
- "required": ["test_dict"],
- },
- rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""",
- [
- ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True),
- ("""{ "test_dict":{"foo":"bar" }}""", True),
- ("""{ "test_dict":{}}""", True),
- ("""{ "WRONG_KEY":{}}""", False),
- ("""{ "test_dict":{"wrong_type" 1}}""", False),
- ],
- ),
- # object containing object
- (
- {
- "title": "TestSchema",
- "type": "object",
- "properties": {
- "test_dict": {
- "title": "Test Dict",
- "additionalProperties": {
- "additionalProperties": {"type": "integer"},
- "type": "object",
- },
- "type": "object",
- }
- },
- "required": ["test_dict"],
- },
- rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""",
- [
- (
- """{"test_dict": {"foo": {"bar": 123, "apple": 99}, "baz": {"bif": 456}}}""",
- True,
- ),
- (
- """{"test_dict": {"anykey": {"anykey": 123}, "anykey2": {"bif": 456}}}""",
- True,
- ),
- ("""{"test_dict": {}}""", True),
- ("""{"test_dict": {"dict of empty dicts are ok": {} }}""", True),
- (
- """{"test_dict": {"anykey": {"ONLY Dict[Dict]": 123}, "No Dict[int]" 1: }}""",
- False,
- ),
- ],
- ),
- # oneOf
- (
- {
- "title": "Foo",
- "oneOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}],
- },
- rf'((?:"{STRING_INNER}*")|(?:{NUMBER})|(?:{BOOLEAN}))',
- [
- ("12.3", True),
- ("true", True),
- ('"a"', True),
- ("null", False),
- ("", False),
- ("12true", False),
- ('1.3"a"', False),
- ('12.3true"a"', False),
- ],
- ),
- # anyOf
- (
- {
- "title": "Foo",
- "anyOf": [{"type": "string"}, {"type": "integer"}],
- },
- rf"({STRING}|{INTEGER})",
- [("12", True), ('"a"', True), ('1"a"', False)],
- ),
- # allOf
- (
- {
- "title": "Foo",
- "allOf": [{"type": "string"}, {"type": "integer"}],
- },
- rf"({STRING}{INTEGER})",
- [('"a"1', True), ('"a"', False), ('"1"', False)],
- ),
- # Tuple / prefixItems
- (
- {
- "title": "Foo",
- "prefixItems": [{"type": "string"}, {"type": "integer"}],
- },
- rf"\[{WHITESPACE}{STRING}{WHITESPACE},{WHITESPACE}{INTEGER}{WHITESPACE}\]",
- [('["a", 1]', True), ('["a", 1, 1]', False), ("[]", False)],
- ),
- # Nested schema
- (
- {
- "title": "Bar",
- "type": "object",
- "properties": {
- "fuzz": {
- "title": "Foo",
- "type": "object",
- "properties": {"spam": {"title": "Spam", "type": "integer"}},
- "required": ["spam"],
- }
- },
- "required": ["fuzz"],
- },
- f'\\{{[ ]?"fuzz"[ ]?:[ ]?\\{{[ ]?"spam"[ ]?:[ ]?{INTEGER}[ ]?\\}}[ ]?\\}}',
- [('{ "fuzz": { "spam": 100 }}', True)],
- ),
- # Schema with a reference
- (
- {
- "title": "User",
- "type": "object",
- "properties": {
- "user_id": {"title": "User Id", "type": "integer"},
- "name": {"title": "Name", "type": "string"},
- "a": {"$ref": "#/properties/name"},
- },
- "required": ["user_id", "name", "a"],
- },
- f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"a"[ ]?:[ ]?{STRING}[ ]?\\}}',
- [('{"user_id": 100, "name": "John", "a": "Marc"}', True)],
- ),
- (
- {
- "title": "User",
- "type": "object",
- "$defs": {"name": {"title": "Name2", "type": "string"}},
- "properties": {
- "user_id": {"title": "User Id", "type": "integer"},
- "name": {"title": "Name", "type": "string"},
- "name2": {"$ref": "#/$defs/name"},
- },
- "required": ["user_id", "name", "name2"],
- },
- f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"name2"[ ]?:[ ]?{STRING}[ ]?\\}}',
- [('{"user_id": 100, "name": "John", "name2": "Marc"}', True)],
- ),
- (
- {
- "$id": "customer",
- "$schema": "https://json-schema.org/draft/2020-12/schema",
- "title": "Customer",
- "type": "object",
- "properties": {
- "name": {"type": "string"},
- "last_name": {"type": "string"},
- "address": {"$ref": "customer#/$defs/address"},
- },
- "required": [
- "name",
- "first_name",
- "last_name",
- "address",
- "shipping_address",
- "billing_address",
- ],
- "$defs": {
- "address": {
- "title": "Address",
- "$schema": "http://json-schema.org/draft-07/schema#",
- "type": "object",
- "properties": {
- "city": {"type": "string"},
- },
- "required": ["street_address", "city", "state"],
- "definitions": {
- "state": {
- "type": "object",
- "title": "State",
- "properties": {"name": {"type": "string"}},
- "required": ["name"],
- }
- },
- }
- },
- },
- f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"last_name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"address"[ ]?:[ ]?\\{{[ ]?"city"[ ]?:[ ]?{STRING}[ ]?\\}}[ ]?\\}}',
- [
- (
- '{"name": "John", "last_name": "Doe", "address": {"city": "Paris"}}',
- True,
- )
- ],
- ),
- # Optional properties
- # Last required property in first position
- (
- {
- "properties": {
- "name": {"type": "string"},
- "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
- "weapon": {"anyOf": [{"type": "string"}, {"type": "null"}]},
- },
- "required": ["name"],
- "title": "Character",
- "type": "object",
- },
- f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"weapon"[ ]?:[ ]?({STRING}|null))?[ ]?\\}}',
- [
- ('{ "name" : "Player" }', True),
- ('{ "name" : "Player", "weapon" : "sword" }', True),
- ('{ "age" : 10, "weapon" : "sword" }', False),
- ],
- ),
- # Last required property in middle position
- (
- {
- "properties": {
- "name": {"type": "string"},
- "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
- "weapon": {"type": "string"},
- "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
- },
- "required": ["name", "weapon"],
- "title": "Character",
- "type": "object",
- },
- f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}',
- [
- ('{ "name" : "Player" , "weapon" : "sword" }', True),
- (
- '{ "name" : "Player", "age" : 10, "weapon" : "sword" , "strength" : 10 }',
- True,
- ),
- ('{ "weapon" : "sword" }', False),
- ],
- ),
- # Last required property in last position
- (
- {
- "properties": {
- "name": {"anyOf": [{"type": "string"}, {"type": "null"}]},
- "age": {"type": "integer"},
- "armor": {"type": "string"},
- "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
- "weapon": {"title": "Weapon", "type": "string"},
- },
- "required": ["age", "armor", "weapon"],
- "title": "Character",
- "type": "object",
- },
- f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"armor"[ ]?:[ ]?{STRING}[ ]?,([ ]?"strength"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}[ ]?\\}}',
- [
- (
- '{ "name" : "Player", "age" : 10, "armor" : "plate", "strength" : 11, "weapon" : "sword" }',
- True,
- ),
- ('{ "age" : 10, "armor" : "plate", "weapon" : "sword" }', True),
- (
- '{ "name" : "Kahlhanbeh", "armor" : "plate", "weapon" : "sword" }',
- False,
- ),
- ],
- ),
- # All properties are optional
- (
- {
- "properties": {
- "name": {"anyOf": [{"type": "string"}, {"type": "null"}]},
- "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
- "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]},
- },
- "title": "Character",
- "type": "object",
- },
- f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?({INTEGER}|null)([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}',
- [
- ('{ "name" : "Player" }', True),
- ('{ "name" : "Player", "age" : 10, "strength" : 10 }', True),
- ('{ "age" : 10, "strength" : 10 }', True),
- ("{ }", True),
- ],
- ),
- ],
-)
-def test_match(schema, regex, examples):
- interegular.parse_pattern(regex)
- schema = json.dumps(schema)
- test_regex = build_regex_from_schema(schema)
- assert test_regex == regex
- for string, does_match in examples:
- match = re.fullmatch(test_regex, string)
- if does_match:
- if match is None:
- raise ValueError(f"Expected match for '{string}'")
- assert match[0] == string
- assert match.span() == (0, len(string))
- else:
- assert match is None
+# a bare function assigned in an Enum body is treated as a method, not a
+# member, so this enum ends up empty; wrap functions in a callable such as
+# functools.partial to register them as values
+class EmptyEnum(Enum):
+ add = add
@pytest.mark.parametrize(
- "schema,regex,examples",
+ "enum,expectation",
[
- # UUID
- (
- {"title": "Foo", "type": "string", "format": "uuid"},
- UUID,
- [
- ("123e4567-e89b-12d3-a456-426614174000", False),
- ('"123e4567-e89b-12d3-a456-426614174000"', True),
- ('"123e4567-e89b-12d3-a456-42661417400"', False),
- ('"123e4567-e89b-12d3-a456-42661417400g"', False),
- ('"123e4567-e89b-12d3-a456-42661417400-"', False),
- ('""', False),
- ],
- ),
- # DATE-TIME
- (
- {"title": "Foo", "type": "string", "format": "date-time"},
- DATE_TIME,
- [
- ("2018-11-13T20:20:39Z", False),
- ('"2018-11-13T20:20:39Z"', True),
- ('"2016-09-18T17:34:02.666Z"', True),
- ('"2008-05-11T15:30:00Z"', True),
- ('"2021-01-01T00:00:00"', True),
- ('"2022-01-10 07:19:30"', False), # missing T
- ('"2022-12-10T10-04-29"', False), # incorrect separator
- ('"2023-01-01"', False),
- ],
- ),
- # DATE
- (
- {"title": "Foo", "type": "string", "format": "date"},
- DATE,
- [
- ("2018-11-13", False),
- ('"2018-11-13"', True),
- ('"2016-09-18"', True),
- ('"2008-05-11"', True),
- ('"2015-13-01"', False), # incorrect month
- ('"2022-01"', False), # missing day
- ('"2022/12/01"', False), # incorrect separator"
- ],
- ),
- # TIME
- (
- {"title": "Foo", "type": "string", "format": "time"},
- TIME,
- [
- ("20:20:39Z", False),
- ('"20:20:39Z"', True),
- ('"15:30:00Z"', True),
- ('"25:30:00"', False), # incorrect hour
- ('"15:30"', False), # missing seconds
- ('"15:30:00.000"', False), # missing Z
- ('"15-30-00"', False), # incorrect separator
- ('"15:30:00+01:00"', False), # incorrect separator
- ],
- ),
+ (MyEnum, nullcontext()),
+ (EmptyEnum, pytest.raises(ValueError)),
],
)
-def test_format(schema, regex, examples):
- interegular.parse_pattern(regex)
- schema = json.dumps(schema)
- test_regex = build_regex_from_schema(schema)
- assert test_regex == regex
-
- for string, does_match in examples:
- match = re.fullmatch(test_regex, string)
- if does_match:
- assert match[0] == string
- assert match.span() == (0, len(string))
- else:
- assert match is None
-
-
-@pytest.mark.parametrize(
- "schema,examples",
- [
- # NESTED UUID
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {"uuid": {"type": "string", "format": "uuid"}},
- },
- [
- ('{"uuid": "123e4567-e89b-12d3-a456-426614174000"}', True),
- ('{"uuid":"123e4567-e89b-12d3-a456-42661417400"}', False),
- ('{"uuid":"123e4567-e89b-12d3-a456-42661417400g"}', False),
- ('{"uuid":"123e4567-e89b-12d3-a456-42661417400-"}', False),
- (
- '{"uuid":123e4567-e89b-12d3-a456-426614174000}',
- False,
- ), # missing quotes for value
- ('{"uuid":""}', False),
- ],
- ),
- # NESTED DATE-TIME
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {"dateTime": {"type": "string", "format": "date-time"}},
- },
- [
- ('{"dateTime": "2018-11-13T20:20:39Z"}', True),
- ('{"dateTime":"2016-09-18T17:34:02.666Z"}', True),
- ('{"dateTime":"2008-05-11T15:30:00Z"}', True),
- ('{"dateTime":"2021-01-01T00:00:00"}', True),
- ('{"dateTime":"2022-01-10 07:19:30"}', False), # missing T
- ('{"dateTime":"2022-12-10T10-04-29"}', False), # incorrect separator
- (
- '{"dateTime":2018-11-13T20:20:39Z}',
- False,
- ), # missing quotes for value
- ('{"dateTime":"2023-01-01"}', False),
- ],
- ),
- # NESTED DATE
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {"date": {"type": "string", "format": "date"}},
- },
- [
- ('{"date": "2018-11-13"}', True),
- ('{"date":"2016-09-18"}', True),
- ('{"date":"2008-05-11"}', True),
- ('{"date":"2015-13-01"}', False), # incorrect month
- ('{"date":"2022-01"}', False), # missing day
- ('{"date":"2022/12/01"}', False), # incorrect separator"
- ('{"date":2018-11-13}', False), # missing quotes for value
- ],
- ),
- # NESTED TIME
- (
- {
- "title": "Foo",
- "type": "object",
- "properties": {"time": {"type": "string", "format": "time"}},
- },
- [
- ('{"time": "20:20:39Z"}', True),
- ('{"time":"15:30:00Z"}', True),
- ('{"time":"25:30:00"}', False), # incorrect hour
- ('{"time":"15:30"}', False), # missing seconds
- ('{"time":"15:30:00.000"}', False), # missing Z
- ('{"time":"15-30-00"}', False), # incorrect separator
- ('{"time":"15:30:00+01:00"}', False), # incorrect separator
- ('{"time":20:20:39Z}', False), # missing quotes for value
- ],
- ),
- # Unconstrained Object
- (
- {
- "title": "Foo",
- "type": "object",
- },
- [
- ("{}", True),
- ('{"a": 1, "b": null}', True),
- ('{"a": {"z": {"g": 4}}, "b": null}', True),
- ("1234", False), # not an object
- ('["a", "a"]', False), # not an array
- ],
- ),
- # Unconstrained Array
- (
- {
- "type": "array",
- },
- [
- ("[1, {}, false]", True),
- ("[{}]", True),
- ('[{"a": {"z": "q"}, "b": null}]', True),
- ('[{"a": [1, 2, true], "b": null}]', True),
- ('[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2]]]', True),
- # too deep, default unconstrained depth limit = 2
- (
- '[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2, [3]]]]',
- False,
- ),
- ('[{"a": {"z": {"g": 4}}, "b": null}]', False),
- ("[[[[1]]]]", False),
- # not an array
- ("{}", False),
- ('{"a": 1, "b": null}', False),
- ('{"a": {"z": {"g": 4}}, "b": null}', False),
- ("1234", False), # not an array
- ('{"a": "a"}', False), # not an array
- ],
- ),
- # No schema / unconstrained value
- (
- {},
- [
- ('"aaabbuecuh"', True), # string
- ("5.554", True), # number
- ("true", True), # boolean
- ("null", True), # null
- ("5999", True), # integer
- ('["a", "b"]', True), # array
- ('{"key": {"k2": "value"}}', True), # nested object
- ("this isnt valid json", False),
- ],
- ),
- ],
-)
-def test_format_without_regex(schema, examples):
- schema = json.dumps(schema)
- test_regex = build_regex_from_schema(schema)
- for string, does_match in examples:
- match = re.fullmatch(test_regex, string)
- if does_match:
- assert match[0] == string
- assert match.span() == (0, len(string))
- else:
- assert match is None
-
-
-@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]*", "abc"])
-def test_json_schema_custom_whitespace_pattern(whitespace_pattern):
- """assert whitespace_pattern setting respected"""
-
- class MockModel(BaseModel):
- foo: int
- bar: str
-
- schema = json.dumps(MockModel.model_json_schema())
-
- # assert any ws pattern can be used
- if whitespace_pattern == "abc":
- build_regex_from_schema(schema, whitespace_pattern)
- return
-
- pattern = build_regex_from_schema(schema, whitespace_pattern)
-
- mock_result_mult_ws = (
- """{ "foo" : 4, \n\n\n "bar": "baz baz baz bar"\n\n}"""
- )
- mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}"""
-
- match_default_ws = re.fullmatch(pattern, mock_result_maybe_ws)
- if whitespace_pattern is None:
- assert match_default_ws
- else:
- assert re.fullmatch(pattern, mock_result_mult_ws)
-
-
-def test_one_of_doesnt_produce_illegal_lookaround():
- """Reproduces failure in https://github.com/dottxt-ai/outlines/issues/823"""
-
- class Cat(BaseModel):
- pet_type: Literal["cat"]
- meows: int
-
- class Dog(BaseModel):
- pet_type: Literal["dog"]
- barks: float
-
- class Model(BaseModel):
- pet: Union[Cat, Dog] = Field(..., discriminator="pet_type")
- n: int
-
- json_schema = Model.schema_json()
-
- pattern = build_regex_from_schema(json_schema, whitespace_pattern=None)
-
- # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm()
- interegular.parse_pattern(pattern).to_fsm()
+def test_enum_schema(enum, expectation):
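+    # EmptyEnum (see note above) has no members, so get_schema_from_enum
+    # is expected to raise ValueError for it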
+ with expectation:
+ schema = get_schema_from_enum(enum)
+ regex_str = build_regex_from_schema(json.dumps(schema))
+ assert isinstance(regex_str, str)
+ assert schema["title"] == enum.__name__
+ assert len(schema["oneOf"]) == len(enum)
+ for elt in schema["oneOf"]:
+ assert type(elt) in [int, float, bool, type(None), str, dict]
diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py
index 9c288c21e..f91bc8653 100644
--- a/tests/generate/test_generate.py
+++ b/tests/generate/test_generate.py
@@ -1,5 +1,6 @@
import contextlib
import re
+from enum import Enum
import pytest
@@ -127,6 +128,18 @@ def model_t5(tmp_path_factory):
)
+class MyEnum(Enum):
+ foo = "foo"
+ bar = "bar"
+ baz = "baz"
+
+
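+# generate.choice accepts either a plain list of strings or an Enum class,
+# so both variants are exercised by the choice tests below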
+ALL_SAMPLE_CHOICES_FIXTURES = (
+ ["foo", "bar", "baz"],
+ MyEnum,
+)
+
+
##########################################
# Structured Generation Inputs
##########################################
@@ -264,21 +277,33 @@ def test_generate_json(request, model_fixture, sample_schema):
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+@pytest.mark.parametrize("sample_choices", ALL_SAMPLE_CHOICES_FIXTURES)
def test_generate_choice(request, model_fixture, sample_choices):
model = request.getfixturevalue(model_fixture)
generator = generate.choice(model, sample_choices)
res = generator(**get_inputs(model_fixture))
- assert res in sample_choices
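+    # type(Enum) is the Enum metaclass (EnumMeta), so this branch detects
+    # Enum classes, whose generated output is compared against member values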
+ if isinstance(sample_choices, type(Enum)):
+ assert res in [elt.value for elt in sample_choices]
+ else:
+ assert res in sample_choices
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
+@pytest.mark.parametrize("sample_choices", ALL_SAMPLE_CHOICES_FIXTURES)
def test_generate_choice_twice(request, model_fixture, sample_choices):
model = request.getfixturevalue(model_fixture)
generator = generate.choice(model, sample_choices)
res = generator(**get_inputs(model_fixture))
- assert res in sample_choices
+ if isinstance(sample_choices, type(Enum)):
+ assert res in [elt.value for elt in sample_choices]
+ else:
+ assert res in sample_choices
+
res = generator(**get_inputs(model_fixture))
- assert res in sample_choices
+ if isinstance(sample_choices, type(Enum)):
+ assert res in [elt.value for elt in sample_choices]
+ else:
+ assert res in sample_choices
@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py
index 8d4596d60..fd5be2171 100644
--- a/tests/generate/test_integration_llamacpp.py
+++ b/tests/generate/test_integration_llamacpp.py
@@ -274,6 +274,7 @@ def test_llama_cpp_pre_tokenizer_remains_broken():
generate.choice(model, ["skirt", "dress", "pen", "jacket"])
+@pytest.mark.skip("Caching for guide was temporarily turned off")
def test_RegexGuide_caching(model, temp_cache_dir):
import llama_cpp
diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py
index 2462d9fcf..92c5d789c 100644
--- a/tests/generate/test_integration_transformers.py
+++ b/tests/generate/test_integration_transformers.py
@@ -1,6 +1,7 @@
import datetime
import re
from enum import Enum
+from functools import partial
from typing import List, Union
import pytest
@@ -354,6 +355,29 @@ class User(BaseModel):
assert result.user_id in [1, 2]
+def add(a: int, b: int) -> int:
+ return a + b
+
+
+def mul(c: float, d: float) -> float:
+ return c * d
+
+
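+# wrapping the functions in functools.partial keeps them as enum *values*;
+# assigned bare, they would become methods and Operation would have no members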
+def test_transformers_json_function_enum(model):
+ prompt = "Output some JSON "
+
+ class Operation(Enum):
+ add = partial(add)
+ mul = partial(mul)
+
+ result = generate.json(model, Operation)(prompt, seed=0)
+ assert isinstance(result, dict)
+ assert len(result) == 2
+ for k, v in result.items():
+ assert k in ["a", "b", "c", "d"]
+ assert isinstance(v, (int, float))
+
+
def test_transformers_json_array(model):
prompt = "Output some JSON "
@@ -492,6 +516,7 @@ def test_transformers_use_existing_model_and_tokenizer():
assert isinstance(sequence, str)
+@pytest.mark.skip("Caching for guide was temporarily turned off")
def test_RegexGuide_caching(temp_cache_dir):
import outlines.caching
from outlines.fsm.guide import cached_create_states_mapping
diff --git a/tests/models/test_mlxlm.py b/tests/models/test_mlxlm.py
new file mode 100644
index 000000000..20e59da81
--- /dev/null
+++ b/tests/models/test_mlxlm.py
@@ -0,0 +1,100 @@
+import pytest
+
+from outlines.models.mlxlm import mlxlm
+from outlines.models.transformers import TransformerTokenizer
+
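+# mlx runs only on Apple Silicon, so guard both the import and the
+# availability of the Metal backend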
+try:
+ import mlx.core as mx
+
+ HAS_MLX = mx.metal.is_available()
+except ImportError:
+ HAS_MLX = False
+
+
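+# a small 4-bit quantized model keeps download size and test runtime low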
+TEST_MODEL = "mlx-community/SmolLM-135M-Instruct-4bit"
+
+
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_model():
+ model = mlxlm(TEST_MODEL)
+ assert hasattr(model, "model")
+ assert hasattr(model, "tokenizer")
+ assert isinstance(model.tokenizer, TransformerTokenizer)
+
+
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_tokenizer():
+ model = mlxlm(TEST_MODEL)
+
+    # Test encoding a single string into token ids
+ test_text = "Hello, world!"
+ token_ids = mx.array(model.mlx_tokenizer.encode(test_text))
+ assert isinstance(token_ids, mx.array)
+
+
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_generate():
+ from outlines.generate.api import GenerationParameters, SamplingParameters
+
+ model = mlxlm(TEST_MODEL)
+ prompt = "Write a haiku about programming:"
+
+ # Test with basic generation parameters
+ gen_params = GenerationParameters(max_tokens=50, stop_at=None, seed=None)
+
+ # Test with different sampling parameters
+ sampling_params = SamplingParameters(
+ sampler="multinomial", num_samples=1, top_p=0.9, top_k=None, temperature=0.7
+ )
+
+ # Test generation
+ output = model.generate(prompt, gen_params, None, sampling_params)
+ assert isinstance(output, str)
+ assert len(output) > 0
+
+
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_stream():
+ from outlines.generate.api import GenerationParameters, SamplingParameters
+
+ model = mlxlm(TEST_MODEL)
+ prompt = "Count from 1 to 5:"
+
+ gen_params = GenerationParameters(max_tokens=20, stop_at=None, seed=None)
+
+ sampling_params = SamplingParameters(
+ sampler="greedy", # Use greedy sampling for deterministic output
+ num_samples=1,
+ top_p=None,
+ top_k=None,
+ temperature=0.0,
+ )
+
+ # Test streaming
+ stream = model.stream(prompt, gen_params, None, sampling_params)
+ tokens = list(stream)
+ assert len(tokens) > 0
+ assert all(isinstance(token, str) for token in tokens)
+
+ # Test that concatenated streaming output matches generate output
+ streamed_text = "".join(tokens)
+ generated_text = model.generate(prompt, gen_params, None, sampling_params)
+ assert streamed_text == generated_text
+
+
+@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon")
+def test_mlxlm_errors():
+ model = mlxlm(TEST_MODEL)
+
+ # Test batch inference (should raise NotImplementedError)
+ with pytest.raises(NotImplementedError):
+ from outlines.generate.api import GenerationParameters, SamplingParameters
+
+ gen_params = GenerationParameters(max_tokens=10, stop_at=None, seed=None)
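+        # positional args: sampler, num_samples, top_p, top_k, temperature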
+ sampling_params = SamplingParameters("multinomial", 1, None, None, 1.0)
+ model.generate(["prompt1", "prompt2"], gen_params, None, sampling_params)
+
+ # Test beam search (should raise NotImplementedError)
+ with pytest.raises(NotImplementedError):
+ sampling_params = SamplingParameters("beam_search", 1, None, None, 1.0)
+ model.generate("test prompt", gen_params, None, sampling_params)
diff --git a/tests/processors/test_base_processor.py b/tests/processors/test_base_processor.py
new file mode 100644
index 000000000..cd9f48278
--- /dev/null
+++ b/tests/processors/test_base_processor.py
@@ -0,0 +1,74 @@
+from typing import List
+
+import numpy as np
+import pytest
+import torch
+
+from outlines.processors.base_logits_processor import OutlinesLogitsProcessor
+
+arrays = {
+ "list": [[1.0, 2.0], [3.0, 4.0]],
+ "np": np.array([[1, 2], [3, 4]], dtype=np.float32),
+ "jax": jnp.array([[1, 2], [3, 4]], dtype=jnp.float32),
+ "torch": torch.tensor([[1, 2], [3, 4]], dtype=torch.float32),
+}
+
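+# mlx and jax are optional dependencies: add their arrays only when installed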
+try:
+ import mlx.core as mx
+
+ arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32)
+except ImportError:
+ pass
+
+try:
+ import jax.numpy as jnp
+
+ arrays["jax"] = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32)
+except ImportError:
+ pass
+
+
+# Mock implementation of the abstract class for testing
+class MockLogitsProcessor(OutlinesLogitsProcessor):
+ def process_logits(
+ self, input_ids: List[List[int]], logits: torch.Tensor
+ ) -> torch.Tensor:
+ # For testing purposes, let's just return logits multiplied by 2
+ return logits * 2
+
+
+@pytest.fixture
+def processor():
+ """Fixture for creating an instance of the MockLogitsProcessor."""
+ return MockLogitsProcessor()
+
+
+@pytest.mark.parametrize("array_type", arrays.keys())
+def test_to_torch(array_type, processor):
+ data = arrays[array_type]
+ torch_tensor = processor._to_torch(data)
+ assert isinstance(torch_tensor, torch.Tensor)
+ assert torch.allclose(
+ torch_tensor.cpu(), torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
+ )
+
+
+@pytest.mark.parametrize("array_type", arrays.keys())
+def test_from_torch(array_type, processor):
+ torch_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32)
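+    # _from_torch converts the tensor back into the caller's original array type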
+ data = processor._from_torch(torch_tensor, type(arrays[array_type]))
+ assert isinstance(data, type(arrays[array_type]))
+ assert np.allclose(data, arrays[array_type])
+
+
+@pytest.mark.parametrize("array_type", arrays.keys())
+def test_call(array_type, processor):
+ input_ids = arrays[array_type]
+ logits = arrays[array_type]
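+    # __call__ converts inputs to torch, applies process_logits, then converts
+    # the result back to the input's array type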
+ processed_logits = processor(input_ids, logits)
+
+ assert isinstance(processed_logits, type(arrays[array_type]))
+ assert np.allclose(
+ np.array(processed_logits), np.array([[2.0, 4.0], [6.0, 8.0]], dtype=np.float32)
+ )