diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 36e6f8526..f54a888fa 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -44,7 +44,7 @@ jobs: echo "::set-output name=id::$MATRIX_ID" - name: Run tests run: | - pytest --cov=outlines + pytest -x --cov=outlines env: COVERAGE_FILE: .coverage.${{ steps.matrix-id.outputs.id }} - name: Upload coverage data diff --git a/.gitignore b/.gitignore index 4984b18cb..08390ae3d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ docs/build *.gguf .venv benchmarks/results +.python-version # Remove doc build folders .cache/ diff --git a/README.md b/README.md index 832703ce6..d34b0984d 100644 --- a/README.md +++ b/README.md @@ -2,17 +2,18 @@ Outlines Logo -[![.txt Twitter][dottxt-twitter-badge]][dottxt-twitter] + + 🗒️ *Make LLMs speak the language of every application.* 🗒️ + +Made with ❤👷️ by the team at [.txt](https://dottxt.co). [![Documentation][documentation-badge]][documentation] [![Contributors][contributors-badge]][contributors] [![Downloads][downloads-badge]][pypistats] [![Discord][discord-badge]][discord] +[Youtube channel][youtube-dottxt] | [.txt blog][blog-dottxt] | [Twitter][dottxt-twitter] -*Robust (structured) text generation.* - -Made with ❤👷️ by the team at [.txt](https://dottxt.co). @@ -83,6 +84,29 @@ generator = outlines.generate.choice(model, ["Positive", "Negative"]) answer = generator(prompt) ``` +You can also pass these choices through en enum: + +````python +from enum import Enum + +import outlines + +class Sentiment(str, Enum): + positive = "Positive" + negative = "Negative" + +model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct") + +prompt = """You are a sentiment-labelling assistant. +Is the following review positive or negative? + +Review: This restaurant is just awesome! +""" + +generator = outlines.generate.choice(model, Sentiment) +answer = generator(prompt) +```` + ### Type constraint You can instruct the model to only return integers or floats: @@ -190,7 +214,7 @@ character = generator("Give me a character description", seed=seed) print(repr(character)) # Character(name='Anderson', age=28, armor=, weapon=, strength=8) -character = generator("Give me an interesting character description", rng=rng) +character = generator("Give me an interesting character description") print(repr(character)) # Character(name='Vivian Thr', age=44, armor=, weapon=, strength=125) @@ -299,6 +323,33 @@ print(add(**result)) A great advantage of passing functions directly to specify the structure is that the structure of the LLM will change with the function's definition. No need to change the code at several places! +You can also embed various functions into an enum to generate params: + +```python +from enum import Enum +from functools import partial + +import outlines + + +def add(a: int, b: int) -> int: + return a + b + +def mul(c: float, d: float) -> float: + return c * d + +class Operation(Enum): + add = partial(add) + mul = partial(mul) + +model = outlines.models.transformers("WizardLM/WizardMath-7B-V1.1") +generator = outlines.generate.json(model, add) +result = generator("Return json with two float named c and d respectively. c is negative and d greater than 1.0.") + +print(result) +# {'c': -3.14, 'd': 1.5} +``` + ## Prompting Building prompts can get messy. **Outlines** makes it easier to write and manage @@ -363,3 +414,5 @@ answer = outlines.generate.text(model)(prompt, max_tokens=100) [downloads-badge]: https://img.shields.io/pypi/dm/outlines?color=89AC6B&logo=python&logoColor=white&style=flat-square [pypistats]: https://pypistats.org/packages/outlines [dottxt-twitter-badge]: https://img.shields.io/twitter/follow/dottxtai?style=social +[youtube-dottxt]: https://www.youtube.com/@dottxt-ai +[blog-dottxt]: https://blog.dottxt.co/ diff --git a/benchmarks/bench_json_schema.py b/benchmarks/bench_json_schema.py index 62d9b3c1d..3a1f72cb6 100644 --- a/benchmarks/bench_json_schema.py +++ b/benchmarks/bench_json_schema.py @@ -1,6 +1,7 @@ +from outlines_core.fsm.json_schema import build_regex_from_schema + from outlines.caching import cache_disabled from outlines.fsm.guide import RegexGuide -from outlines.fsm.json_schema import build_regex_from_schema from .common import setup_tokenizer # noqa: E402 @@ -70,10 +71,6 @@ def setup(self, schema_name): self.tokenizer = setup_tokenizer() self.schema = schemas[schema_name] - @cache_disabled() - def time_json_schema_to_regex(self, schema_name): - build_regex_from_schema(self.schema) - @cache_disabled() def time_json_schema_to_fsm(self, schema_name): regex = build_regex_from_schema(self.schema) diff --git a/benchmarks/bench_processors.py b/benchmarks/bench_processors.py index 5b4901540..db1e4a8f1 100644 --- a/benchmarks/bench_processors.py +++ b/benchmarks/bench_processors.py @@ -9,6 +9,12 @@ except ImportError: pass +try: + import jax + import jax.numpy as jnp +except ImportError: + pass + def is_mlx_lm_allowed(): try: @@ -18,6 +24,14 @@ def is_mlx_lm_allowed(): return mx.metal.is_available() +def is_jax_allowed(): + try: + import jax # noqa: F401 + except ImportError: + return False + return True + + def get_mock_processor_inputs(array_library, num_tokens=30000): """ logits: (4, 30,000 ) dtype=float @@ -43,6 +57,13 @@ def get_mock_processor_inputs(array_library, num_tokens=30000): input_ids = mx.random.randint( low=0, high=num_tokens, shape=(4, 2048), dtype=mx.int32 ) + elif array_library == "jax": + logits = jnp.random.uniform( + key=jax.random.PRNGKey(0), shape=(4, num_tokens), dtype=jnp.float32 + ) + input_ids = jnp.random.randint( + key=jax.random.PRNGKey(0), low=0, high=num_tokens, shape=(4, 2048) + ) else: raise ValueError @@ -67,6 +88,8 @@ class LogitsProcessorPassthroughBenchmark: params += ["mlx"] if torch.cuda.is_available(): params += ["torch_cuda"] + if is_jax_allowed(): + params += ["jax"] def setup(self, array_library): self.logits_processor = HalvingLogitsProcessor() diff --git a/docs/cookbook/extract_event_details.md b/docs/cookbook/extract_event_details.md new file mode 100644 index 000000000..0f87d9586 --- /dev/null +++ b/docs/cookbook/extract_event_details.md @@ -0,0 +1,34 @@ +This recipe demonstrates how to use the `outlines` library to extract structured event details from a text message. +We will extract the title, location, and start date and time from messages like the following: + +```plaintext +Hello Kitty, my grandmother will be here, I think it's better to postpone +our appointment to review math lessons to next Monday at 2pm at the same +place, 3 avenue des tanneurs, one hour will be enough see you 😘 +``` + +Let see how to extract the event details from the message with the MLX +library dedicated to Apple Silicon processor (M series). + +```python +--8<-- "docs/cookbook/extract_event_details.py" +``` + +The output will be: + +```plaintext +Today: Saturday 16 November 2024 and it's 10:55 +``` + +and the extracted event information will be: + +```json +{ + "title":"Math Review", + "location":"3 avenue des tanneurs", + "start":"2024-11-22T14:00:00Z" +} +``` + + +To find out more about this use case, we recommend the project developped by [Joseph Rudoler](https://x.com/JRudoler) the [ICS Generator](https://github.com/jrudoler/ics-generator) diff --git a/docs/cookbook/extract_event_details.py b/docs/cookbook/extract_event_details.py new file mode 100644 index 000000000..b51f8d921 --- /dev/null +++ b/docs/cookbook/extract_event_details.py @@ -0,0 +1,46 @@ +from datetime import datetime + +from pydantic import BaseModel, Field + +from outlines import generate, models + +# Load the model +model = models.mlxlm("mlx-community/Hermes-3-Llama-3.1-8B-8bit") + + +# Define the event schema using Pydantic +class Event(BaseModel): + title: str = Field(description="title of the event") + location: str + start: datetime = Field( + default=None, description="date of the event if available in iso format" + ) + + +# Get the current date and time +now = datetime.now().strftime("%A %d %B %Y and it's %H:%M") + +# Define the prompt +prompt = f""" +Today's date and time are {now} +Given a user message, extract information of the event like date and time in iso format, location and title. +If the given date is relative, think step by step to find the right date. +Here is the message: +""" + +# Sample message +message = """Hello Kitty, my grandmother will be here , I think it's better to postpone our +appointment to review math lessons to next Friday at 2pm at the same place, 3 avenue des tanneurs, I think that one hour will be enough +see you 😘 """ + +# Create the generator +generator = generate.json(model, Event) + +# Extract the event information +event = generator(prompt + message) + +# Print the current date and time +print(f"Today: {now}") + +# Print the extracted event information in JSON format +print(event.json()) diff --git a/docs/cookbook/images/trader-joes-receipt.jpg b/docs/cookbook/images/trader-joes-receipt.jpg new file mode 100644 index 000000000..6742d3946 Binary files /dev/null and b/docs/cookbook/images/trader-joes-receipt.jpg differ diff --git a/docs/cookbook/index.md b/docs/cookbook/index.md index b163feb62..c36b98969 100644 --- a/docs/cookbook/index.md +++ b/docs/cookbook/index.md @@ -12,4 +12,7 @@ This part of the documentation provides a few cookbooks that you can browse to g - [Knowledge Graph Generation](knowledge_graph_extraction.md): Generate a Knowledge Graph from unstructured text using JSON-structured generation. - [Chain Of Thought (CoT)](chain_of_thought.md): Generate a series of intermediate reasoning steps using regex-structured generation. - [ReAct Agent](react_agent.md): Build an agent with open weights models using regex-structured generation. +- [Earnings reports to CSV](earnings-reports.md): Extract data from earnings reports to CSV using regex-structured generation. - [Vision-Language Models](atomic_caption.md): Use Outlines with vision-language models for tasks like image captioning and visual reasoning. +- [Receipt Digitization](receipt-digitization.md): Extract information from a picture of a receipt using structured generation. +- [Structured Generation from PDFs](read-pdfs.md): Use Outlines with vision-language models to read PDFs and produce structured output. diff --git a/docs/cookbook/read-pdfs.md b/docs/cookbook/read-pdfs.md new file mode 100644 index 000000000..dbe4ccb02 --- /dev/null +++ b/docs/cookbook/read-pdfs.md @@ -0,0 +1,376 @@ +# PDF to structured output with vision language models + +A common task with language models is to ask language models questions about a PDF file. + +Typically, the output is unstructured text, i.e. "talking" to your PDF. + +In some cases, you may wish to extract structured information from the PDF, like tables, lists, citations, etc. + +PDFs are difficult to machine read. However, you can simply convert the PDF to images, and then use a vision language model to extract structured information from the images. + +This cookbook demonstrates how to + +1. Convert a PDF to a list of images +2. Use a vision language model to extract structured information from the images + +## Dependencies + +You'll need to install these dependencies: + +```bash +pip install outlines pillow transformers torch==2.4.0 pdf2image + +# Optional, but makes the output look nicer +pip install rich +``` + +## Import the necessary libraries + +```python +from PIL import Image +import outlines +import torch +from transformers import AutoProcessor +from pydantic import BaseModel +from typing import List, Optional +from pdf2image import convert_from_path +import os +from rich import print +import requests +``` + +## Choose a model + +We've tested this example with [Pixtral 12b](https://huggingface.co/mistral-community/pixtral-12b) and [Qwen2-VL-7B-Instruct](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct). + +To use Pixtral: + +```python +from transformers import LlavaForConditionalGeneration +model_name="mistral-community/pixtral-12b" +model_class=LlavaForConditionalGeneration +``` + +To use Qwen-2-VL: + +```python +from transformers import Qwen2VLForConditionalGeneration +model_name = "Qwen/Qwen2-VL-7B-Instruct" +model_class = Qwen2VLForConditionalGeneration +``` + +You can load your model into memory with: + +```python +# This loads the model into memory. On your first run, +# it will have to download the model, which might take a while. +model = outlines.models.transformers_vision( + model_name, + model_class=model_class, + model_kwargs={ + "device_map": "auto", + "torch_dtype": torch.bfloat16, + }, + processor_kwargs={ + "device": "auto", + }, +) +``` + +## Convert the PDF to images + +We'll use the `pdf2image` library to convert each page of the PDF to an image. + +`convert_pdf_to_images` is a convenience function that converts each page of the PDF to an image, and optionally saves the images to disk when `output_dir` is provided. + +Note: the `dpi` argument is important. It controls the resolution of the images. High DPI images are higher quality and may yield better results, +but they are also larger, slower to process, and require more memory. + +```python +from pdf2image import convert_from_path +from PIL import Image +import os +from typing import List, Optional + +def convert_pdf_to_images( + pdf_path: str, + output_dir: Optional[str] = None, + dpi: int = 120, + fmt: str = 'PNG' +) -> List[Image.Image]: + """ + Convert a PDF file to a list of PIL Image objects. + + Args: + pdf_path: Path to the PDF file + output_dir: Optional directory to save the images + dpi: Resolution for the conversion. High DPI is high quality, but also slow and memory intensive. + fmt: Output format (PNG recommended for quality) + + Returns: + List of PIL Image objects + """ + # Convert PDF to list of images + images = convert_from_path( + pdf_path, + dpi=dpi, + fmt=fmt + ) + + # Optionally save images + if output_dir: + os.makedirs(output_dir, exist_ok=True) + for i, image in enumerate(images): + image.save(os.path.join(output_dir, f'page_{i+1}.{fmt.lower()}')) + + return images +``` + +We're going to use the [Louf & Willard paper](https://arxiv.org/pdf/2307.09702) that described the method that Outlines uses for structured generation. + +To download the PDF, run: + +```python +# Download the PDF file +pdf_url = "https://arxiv.org/pdf/2307.09702" +response = requests.get(pdf_url) + +# Save the PDF locally +with open("louf-willard.pdf", "wb") as f: + f.write(response.content) +``` + +Now, we can convert the PDF to a list of images: + +```python +# Load the pdf +images = convert_pdf_to_images( + "louf-willard.pdf", + dpi=120, + output_dir="output_images" +) +``` + +## Extract structured information from the images + +The structured output you can extract is exactly the same as everywhere else in Outlines -- you can use regular expressions, JSON schemas, selecting from a list of options, etc. + +### Extracting data into JSON + +Suppose you wished to go through each page of the PDF, and extract the page description, key takeaways, and page number. + +You can do this by defining a JSON schema, and then using `outlines.generate.json` to extract the data. + +First, define the structure you want to extract: + +```python +class PageSummary(BaseModel): + description: str + key_takeaways: List[str] + page_number: int +``` + +Second, we need to set up the prompt. Adding special tokens can be tricky, so we use the transformers `AutoProcessor` to apply the special tokens for us. To do so, we specify a list of messages, where each message is a dictionary with a `role` and `content` key. + +Images are denoted with `type: "image"`, and text is denoted with `type: "text"`. + +```python +messages = [ + { + "role": "user", + "content": [ + # The text you're passing to the model -- + # this is where you do your standard prompting. + {"type": "text", "text": f""" + Describe the page in a way that is easy for a PhD student to understand. + + Return the information in the following JSON schema: + {PageSummary.model_json_schema()} + + Here is the page: + """ + }, + + # Don't need to pass in an image, since we do this + # when we call the generator function down below. + {"type": "image", "image": ""}, + ], + } +] + +# Convert the messages to the final prompt +processor = AutoProcessor.from_pretrained(model_name) +instruction = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) +``` + +Now we iterate through each image, and extract the structured information: + +```python +# Page summarizer function +page_summary_generator = outlines.generate.json(model, PageSummary) + +for image in images: + result = page_summary_generator(instruction, [image]) + print(result) +``` + +### Regular expressions to extract the arxiv paper identifier + +The [arXiv paper identifier](https://info.arxiv.org/help/arxiv_identifier.html) is a unique identifier for each paper. These identifiers have the format `arXiv:YYMM.NNNNN` (five end digits) or `arXiv:YYMM.NNNN` (four end digits). arXiv identifiers are typically watermarked on papers uploaded to arXiv. + +arXiv identifiers are optionally followed by a version number, i.e. `arXiv:YYMM.NNNNNvX`. + +We can use a regular expression to define this patter: + +```python +paper_regex = r'arXiv:\d{2}[01]\d\.\d{4,5}(v\d)?' +``` + +We can build an extractor function from the regex: + +```python +id_extractor = outlines.generate.regex(model, paper_regex) +``` + +Now, we can extract the arxiv paper identifier from the first image: + +```python +arxiv_instruction = processor.apply_chat_template( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": f""" + Extract the arxiv paper identifier from the page. + + Here is the page: + """}, + {"type": "image", "image": ""}, + ], + } + ], + tokenize=False, + add_generation_prompt=True +) + +# Extract the arxiv paper identifier +paper_id = id_extractor(arxiv_instruction, [images[0]]) +``` + +As of the time of this writing, the arxiv paper identifier is + +``` +arXiv:2307.09702v4 +``` + +Your version number may be different, but the part before `vX` should match. + +### Categorize the paper into one of several categories + +`outlines.generate.choice` allows the model to select one of several options. Suppose we wanted to categorize the paper into being about "language models", "economics", "structured generation", or "other". + +Let's define a few categories we might be interested in: + +```python +categories = [ + "llms", + "cell biology", + "other" +] +``` + +Now we can construct the prompt: + +```python +categorization_instruction = processor.apply_chat_template( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": f""" + Please choose one of the following categories + that best describes the paper. + + {categories} + + Here is the paper: + """}, + + {"type": "image", "image": ""}, + ], + } + ], + tokenize=False, + add_generation_prompt=True +) +``` + +Now we can show the model the first page and extract the category: + +```python +# Build the choice extractor +categorizer = outlines.generate.choice( + model, + categories +) + +# Categorize the paper +category = categorizer(categorization_instruction, [images[0]]) +print(category) +``` + +Which should return: + +``` +llms +``` + +## Additional notes + +You can provide multiple images to the model by + +1. Adding additional image messages +2. Providing a list of images to the `generate` function + +For example, to have two images, you can do: + +```python +two_image_prompt = processor.apply_chat_template( + [ + { + "role": "user", + "content": [ + {"type": "text", "text": "are both of these images of hot dogs?"}, + + # Tell the model there are two images + {"type": "image", "image": ""}, + {"type": "image", "image": ""}, + ], + } + ], + tokenize=False, + add_generation_prompt=True +) + +# Pass two images to the model +generator = outlines.generate.choice( + model, + ["hot dog", "not hot dog"] +) + +result = generator( + two_image_prompt, + + # Pass two images to the model + [images[0], images[1]] +) +print(result) +``` + +Using the first to pages of the paper (they are not images of hot dogs), we should get + +``` +not hot dog +``` diff --git a/docs/cookbook/receipt-digitization.md b/docs/cookbook/receipt-digitization.md new file mode 100644 index 000000000..67830fa81 --- /dev/null +++ b/docs/cookbook/receipt-digitization.md @@ -0,0 +1,296 @@ +# Receipt Data Extraction with VLMs + +## Setup + +You'll need to install the dependencies: + +```bash +pip install outlines torch==2.4.0 transformers accelerate pillow rich +``` + +## Import libraries + +Load all the necessary libraries: + +```python +# LLM stuff +import outlines +import torch +from transformers import AutoProcessor +from pydantic import BaseModel, Field +from typing import Literal, Optional, List + +# Image stuff +from PIL import Image +import requests + +# Rich for pretty printing +from rich import print +``` + +## Choose a model + +This example has been tested with `mistral-community/pixtral-12b` ([HF link](https://huggingface.co/mistral-community/pixtral-12b)) and `Qwen/Qwen2-VL-7B-Instruct` ([HF link](https://huggingface.co/Qwen/Qwen2-VL-7B-Instruct)). + +We recommend Qwen-2-VL as we have found it to be more accurate than Pixtral. + +If you want to use Qwen-2-VL, you can do the following: + +```python +# To use Qwen-2-VL: +from transformers import Qwen2VLForConditionalGeneration +model_name = "Qwen/Qwen2-VL-7B-Instruct" +model_class = Qwen2VLForConditionalGeneration +``` + +If you want to use Pixtral, you can do the following: + +```python +# To use Pixtral: +from transformers import LlavaForConditionalGeneration +model_name="mistral-community/pixtral-12b" +model_class=LlavaForConditionalGeneration +``` + +## Load the model + +Load the model into memory: + +```python +model = outlines.models.transformers_vision( + model_name, + model_class=model_class, + model_kwargs={ + "device_map": "auto", + "torch_dtype": torch.bfloat16, + }, + processor_kwargs={ + "device": "cuda", # set to "cpu" if you don't have a GPU + }, +) +``` + +## Image processing + +Images can be quite large. In GPU-poor environments, you may need to resize the image to a smaller size. + +Here's a helper function to do that: + +```python +def load_and_resize_image(image_path, max_size=1024): + """ + Load and resize an image while maintaining aspect ratio + + Args: + image_path: Path to the image file + max_size: Maximum dimension (width or height) of the output image + + Returns: + PIL Image: Resized image + """ + image = Image.open(image_path) + + # Get current dimensions + width, height = image.size + + # Calculate scaling factor + scale = min(max_size / width, max_size / height) + + # Only resize if image is larger than max_size + if scale < 1: + new_width = int(width * scale) + new_height = int(height * scale) + image = image.resize((new_width, new_height), Image.Resampling.LANCZOS) + + return image +``` + +You can change the resolution of the image by changing the `max_size` argument. Small max sizes will make the image more blurry, but processing will be faster and require less memory. + +## Load an image + +Load an image and resize it. We've provided a sample image of a Trader Joe's receipt, but you can use any image you'd like. + +Here's what the image looks like: + +![Trader Joe's receipt](./images/trader-joes-receipt.jpg) + +```python +# Path to the image +image_path = "https://dottxt-ai.github.io/outlines/main/cookbook/images/trader-joes-receipt.png" + +# Download the image +response = requests.get(image_path) +with open("receipt.png", "wb") as f: + f.write(response.content) + +# Load + resize the image +image = load_and_resize_image("receipt.png") +``` + +## Define the output structure + +We'll define a Pydantic model to describe the data we want to extract from the image. + +In our case, we want to extract the following information: + +- The store name +- The store address +- The store number +- A list of items, including the name, quantity, price per unit, and total price +- The tax +- The total +- The date +- The payment method + +Most fields are optional, as not all receipts contain all information. + +```python +class Item(BaseModel): + name: str + quantity: Optional[int] + price_per_unit: Optional[float] + total_price: Optional[float] + +class ReceiptSummary(BaseModel): + store_name: str + store_address: str + store_number: Optional[int] + items: List[Item] + tax: Optional[float] + total: Optional[float] + # Date is in the format YYYY-MM-DD. We can apply a regex pattern to ensure it's formatted correctly. + date: Optional[str] = Field(pattern=r'\d{4}-\d{2}-\d{2}', description="Date in the format YYYY-MM-DD") + payment_method: Literal["cash", "credit", "debit", "check", "other"] +``` + +## Prepare the prompt + +We'll use the `AutoProcessor` to convert the image and the text prompt into a format that the model can understand. Practically, +this is the code that adds user, system, assistant, and image tokens to the prompt. + +```python +# Set up the content you want to send to the model +messages = [ + { + "role": "user", + "content": [ + { + # The image is provided as a PIL Image object + "type": "image", + "image": image, + }, + { + "type": "text", + "text": f"""You are an expert at extracting information from receipts. + Please extract the information from the receipt. Be as detailed as possible -- + missing or misreporting information is a crime. + + Return the information in the following JSON schema: + {ReceiptSummary.model_json_schema()} + """}, + ], + } +] + +# Convert the messages to the final prompt +processor = AutoProcessor.from_pretrained(model_name) +prompt = processor.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True +) +``` + +If you are curious, the final prompt that is sent to the model looks (roughly) like this: + +``` +<|im_start|>system +You are a helpful assistant.<|im_end|> +<|im_start|>user +<|vision_start|><|image_pad|><|vision_end|> +You are an expert at extracting information from receipts. +Please extract the information from the receipt. Be as detailed as +possible -- missing or misreporting information is a crime. + +Return the information in the following JSON schema: + + +<|im_end|> +<|im_start|>assistant +``` + +## Run the model + +```python +# Prepare a function to process receipts +receipt_summary_generator = outlines.generate.json( + model, + ReceiptSummary, + + # Greedy sampling is a good idea for numeric + # data extraction -- no randomness. + sampler=outlines.samplers.greedy() +) + +# Generate the receipt summary +result = receipt_summary_generator(prompt, [image]) +print(result) +``` + +## Output + +The output should look like this: + +``` +ReceiptSummary( + store_name="Trader Joe's", + store_address='401 Bay Street, San Francisco, CA 94133', + store_number=0, + items=[ + Item(name='BANANA EACH', quantity=7, price_per_unit=0.23, total_price=1.61), + Item(name='BAREBELLS CHOCOLATE DOUG', quantity=1, price_per_unit=2.29, total_price=2.29), + Item(name='BAREBELLS CREAMY CRISP', quantity=1, price_per_unit=2.29, total_price=2.29), + Item(name='BAREBELLS CHOCOLATE DOUG', quantity=1, price_per_unit=2.29, total_price=2.29), + Item(name='BAREBELLS CARAMEL CASHEW', quantity=2, price_per_unit=2.29, total_price=4.58), + Item(name='BAREBELLS CREAMY CRISP', quantity=1, price_per_unit=2.29, total_price=2.29), + Item(name='SPINDRIFT ORANGE MANGO 8', quantity=1, price_per_unit=7.49, total_price=7.49), + Item(name='Bottle Deposit', quantity=8, price_per_unit=0.05, total_price=0.4), + Item(name='MILK ORGANIC GALLON WHOL', quantity=1, price_per_unit=6.79, total_price=6.79), + Item(name='CLASSIC GREEK SALAD', quantity=1, price_per_unit=3.49, total_price=3.49), + Item(name='COBB SALAD', quantity=1, price_per_unit=5.99, total_price=5.99), + Item(name='PEPPER BELL RED XL EACH', quantity=1, price_per_unit=1.29, total_price=1.29), + Item(name='BAG FEE.', quantity=1, price_per_unit=0.25, total_price=0.25), + Item(name='BAG FEE.', quantity=1, price_per_unit=0.25, total_price=0.25) + ], + tax=0.68, + total=41.98, + date='2023-11-04', + payment_method='debit', + +) +``` + +Voila! You've successfully extracted information from a receipt using an LLM. + +## Bonus: roasting the user for their receipt + +You can roast the user for their receipt by adding a `roast` field to the end of the `ReceiptSummary` model. + +```python +class ReceiptSummary(BaseModel): + ... + roast: str +``` + +which gives you a result like + +``` +ReceiptSummary( + ... + roast="You must be a fan of Trader Joe's because you bought enough + items to fill a small grocery bag and still had to pay for a bag fee. + Maybe you should start using reusable bags to save some money and the + environment." +) +``` + +Qwen is not particularly funny, but worth a shot. diff --git a/docs/overrides/home.html b/docs/overrides/home.html index 0c97d5ac6..6525ab062 100644 --- a/docs/overrides/home.html +++ b/docs/overrides/home.html @@ -7,27 +7,15 @@ {{ super() }}
@@ -89,26 +122,24 @@ Outlines Logo
-

+

Structured text generation and robust prompting for language models

- - - -

Made with ❤️ by the team at .txt

diff --git a/docs/reference/generation/generation.md b/docs/reference/generation/generation.md index a14818514..930ad9d22 100644 --- a/docs/reference/generation/generation.md +++ b/docs/reference/generation/generation.md @@ -4,7 +4,7 @@ title: Generation # Generation -Once an [Outlines model](../models) is constructed you can use `outlines.generate` to generate text. Standard LLM generation is possible via `outlines.generate.text`, along with a variety of structured generation methods described below. (For a detailed technical explanation of how structured generation works, you may review the [Structured Generation Explanation](./structured_generation_explanation.md) page) +Once an [Outlines model](../models/models.md) is constructed you can use `outlines.generate` to generate text. Standard LLM generation is possible via `outlines.generate.text`, along with a variety of structured generation methods described below. (For a detailed technical explanation of how structured generation works, you may review the [Structured Generation Explanation](./structured_generation_explanation.md) page) Before generating text, you must construct an `outlines.model`. Example: diff --git a/docs/reference/models/llamacpp.md b/docs/reference/models/llamacpp.md index 24b0fdc97..51b62eca8 100644 --- a/docs/reference/models/llamacpp.md +++ b/docs/reference/models/llamacpp.md @@ -4,7 +4,11 @@ Outlines provides an integration with [Llama.cpp](https://github.com/ggerganov/l !!! Note "Installation" - You need to install the `llama-cpp-python` library to use the llama.cpp integration. See the [installation section](#installation) for instructions to install `llama-cpp-python` with CUDA, Metal, ROCm and other backends. + You need to install the `llama-cpp-python` library to use the llama.cpp integration. See the [installation section](#installation) for instructions to install `llama-cpp-python` with CUDA, Metal, ROCm and other backends. To get started quickly you can also run: + + ```bash + pip install "outlines[llamacpp]" + ``` ## Load the model diff --git a/docs/reference/models/mlxlm.md b/docs/reference/models/mlxlm.md index cf7bb7443..d435b9c1f 100644 --- a/docs/reference/models/mlxlm.md +++ b/docs/reference/models/mlxlm.md @@ -4,7 +4,11 @@ Outlines provides an integration with [mlx-lm](https://github.com/ml-explore/mlx !!! Note "Installation" - You need to install the `mlx` and `mlx-lm` libraries on a device which [supports Metal](https://support.apple.com/en-us/102894) to use the mlx-lm integration. + You need to install the `mlx` and `mlx-lm` libraries on a device which [supports Metal](https://support.apple.com/en-us/102894) to use the mlx-lm integration. To get started quickly you can also run: + + ```bash + pip install "outlines[mlxlm]" + ``` ## Load the model diff --git a/docs/reference/models/openai.md b/docs/reference/models/openai.md index 5c737c916..638107568 100644 --- a/docs/reference/models/openai.md +++ b/docs/reference/models/openai.md @@ -2,7 +2,11 @@ !!! Installation - You need to install the `openai` library to be able to use the OpenAI API in Outlines. + You need to install the `openai` library to be able to use the OpenAI API in Outlines. Or alternatively: + + ```bash + pip install "outlines[openai]" + ``` ## OpenAI models diff --git a/docs/reference/models/transformers.md b/docs/reference/models/transformers.md index 2a13e28ec..f4c319540 100644 --- a/docs/reference/models/transformers.md +++ b/docs/reference/models/transformers.md @@ -3,10 +3,10 @@ !!! Installation - You need to install the `transformer`, `datasets` and `torch` libraries to be able to use these models in Outlines: + You need to install the `transformer`, `datasets` and `torch` libraries to be able to use these models in Outlines, or alternatively: ```bash - pip install torch transformers datasets + pip install "outlines[transformers]" ``` diff --git a/docs/reference/models/vllm.md b/docs/reference/models/vllm.md index fb1c830fa..8789b588e 100644 --- a/docs/reference/models/vllm.md +++ b/docs/reference/models/vllm.md @@ -3,7 +3,11 @@ !!! Note "Installation" - You need to install the `vllm` library to use the vLLM integration. See the [installation section](#installation) for instructions to install vLLM for CPU or ROCm. + You need to install the `vllm` library to use the vLLM integration. See the [installation section](#installation) for instructions to install vLLM for CPU or ROCm. To get started you can also run: + + ```bash + pip install "outlines[vllm]" + ``` ## Load the model diff --git a/docs/stylesheets/extra.css b/docs/stylesheets/extra.css index 33783c153..e71a6124f 100644 --- a/docs/stylesheets/extra.css +++ b/docs/stylesheets/extra.css @@ -112,6 +112,9 @@ h1.title { h2.subtitle { margin: 5px 0px 25px; + font-size: 1rem; + max-width: 540px; + margin: 0 auto; } .md-typeset { diff --git a/examples/beam-cloud/README.md b/examples/beam-cloud/README.md new file mode 100644 index 000000000..5f190d76f --- /dev/null +++ b/examples/beam-cloud/README.md @@ -0,0 +1,5 @@ +## Deploy Outlines on Beam + +1. Create an account [here](https://beam.cloud) and install the Beam SDK +2. Download the `app.py` file to your computer +3. Deploy it as a serverless API by running: `beam deploy app.py:predict` diff --git a/examples/beam-cloud/app.py b/examples/beam-cloud/app.py new file mode 100644 index 000000000..fb6c2cb2b --- /dev/null +++ b/examples/beam-cloud/app.py @@ -0,0 +1,39 @@ +from beam import Image, endpoint, env + +if env.is_remote(): + import outlines + + +# Pre-load models when the container first starts +def load_models(): + import outlines + + model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct") + return model + + +@endpoint( + name="outlines-serverless", + gpu="A10G", + cpu=1, + memory="16Gi", + on_start=load_models, + image=Image().add_python_packages( + ["outlines", "torch", "transformers", "accelerate"] + ), +) +def predict(context, **inputs): + default_prompt = """You are a sentiment-labelling assistant. + Is the following review positive or negative? + + Review: This restaurant is just awesome! + """ + + prompt = inputs.get("prompt", default_prompt) + + # Unpack cached model from context + model = context.on_start_value + # Inference + generator = outlines.generate.choice(model, ["Positive", "Negative"]) + answer = generator(prompt) + return {"answer": answer} diff --git a/mkdocs.yml b/mkdocs.yml index b97a2e41e..4888309b9 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -74,6 +74,7 @@ markdown_extensions: - pymdownx.emoji: emoji_index: !!python/name:material.extensions.emoji.twemoji emoji_generator: !!python/name:material.extensions.emoji.to_svg + - pymdownx.snippets: extra_css: @@ -120,7 +121,10 @@ nav: - Chain of Thought (CoT): cookbook/chain_of_thought.md - ReAct Agent: cookbook/react_agent.md - Vision-Language Models: cookbook/atomic_caption.md + - Structured Generation from PDFs: cookbook/read-pdfs.md - Earnings reports to CSV: cookbook/earnings-reports.md + - Digitizing receipts with vision models: cookbook/receipt-digitization.md + - Extract events details from text: cookbook/extract_event_details.md - Run on the cloud: - BentoML: cookbook/deploy-using-bentoml.md - Cerebrium: cookbook/deploy-using-cerebrium.md diff --git a/outlines/base.py b/outlines/base.py index 4de8ccf5a..29d42c54c 100644 --- a/outlines/base.py +++ b/outlines/base.py @@ -5,12 +5,23 @@ from typing import Callable, Optional import numpy as np -from numpy.lib.function_base import ( - _calculate_shapes, - _parse_gufunc_signature, - _parse_input_dimensions, - _update_dim_sizes, -) + +# Import required functions based on NumPy version +np_major_version = int(np.__version__.split(".")[0]) +if np_major_version >= 2: + from numpy.lib._function_base_impl import ( + _calculate_shapes, + _parse_gufunc_signature, + _parse_input_dimensions, + _update_dim_sizes, + ) +else: + from numpy.lib.function_base import ( + _calculate_shapes, + _parse_gufunc_signature, + _parse_input_dimensions, + _update_dim_sizes, + ) # Allow nested loops for running in notebook. We don't enable it globally as it # may interfere with other libraries that use asyncio. diff --git a/outlines/fsm/guide.py b/outlines/fsm/guide.py index bbd3f44ba..9df2e9d12 100644 --- a/outlines/fsm/guide.py +++ b/outlines/fsm/guide.py @@ -15,7 +15,6 @@ ) from outlines import grammars -from outlines.caching import cache from outlines.fsm.parsing import PartialLark, PartialParserState if TYPE_CHECKING: @@ -73,7 +72,6 @@ def copy(self): return self -@cache() def cached_create_states_mapping(regex_string, tokenizer, *args, **kwargs): return uncached_create_states_mapping(regex_string, tokenizer, *args, **kwargs) diff --git a/outlines/fsm/json_schema.py b/outlines/fsm/json_schema.py index 98d2de59c..578ee7626 100644 --- a/outlines/fsm/json_schema.py +++ b/outlines/fsm/json_schema.py @@ -1,89 +1,10 @@ import inspect import json -import re import warnings -from typing import Callable, Optional, Tuple, Type, Union +from enum import Enum +from typing import Callable, Type, Union -from jsonschema.protocols import Validator from pydantic import BaseModel, create_model -from referencing import Registry, Resource -from referencing._core import Resolver -from referencing.jsonschema import DRAFT202012 - -# allow `\"`, `\\`, or any character which isn't a control sequence -STRING_INNER = r'([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])' -STRING = f'"{STRING_INNER}*"' - -INTEGER = r"(-)?(0|[1-9][0-9]*)" -NUMBER = rf"({INTEGER})(\.[0-9]+)?([eE][+-][0-9]+)?" -BOOLEAN = r"(true|false)" -NULL = r"null" -WHITESPACE = r"[ ]?" - -type_to_regex = { - "string": STRING, - "integer": INTEGER, - "number": NUMBER, - "boolean": BOOLEAN, - "null": NULL, -} - -DATE_TIME = r'"(-?(?:[1-9][0-9]*)?[0-9]{4})-(1[0-2]|0[1-9])-(3[01]|0[1-9]|[12][0-9])T(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\.[0-9]{3})?(Z)?"' -DATE = r'"(?:\d{4})-(?:0[1-9]|1[0-2])-(?:0[1-9]|[1-2][0-9]|3[0-1])"' -TIME = r'"(2[0-3]|[01][0-9]):([0-5][0-9]):([0-5][0-9])(\\.[0-9]+)?(Z)?"' -UUID = r'"[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}"' - -format_to_regex = { - "uuid": UUID, - "date-time": DATE_TIME, - "date": DATE, - "time": TIME, -} - - -def build_regex_from_schema(schema: str, whitespace_pattern: Optional[str] = None): - """Turn a JSON schema into a regex that matches any JSON object that follows - this schema. - - JSON Schema is a declarative language that allows to annotate JSON documents - with types and descriptions. These schemas can be generated from any Python - datastructure that has type annotation: namedtuples, dataclasses, Pydantic - models. And by ensuring that the generation respects the schema we ensure - that the output can be parsed into these objects. - This function parses the provided schema and builds a generation schedule which - mixes deterministic generation (fixed strings), and sampling with constraints. - - Parameters - ---------- - schema - A string that represents a JSON Schema. - whitespace_pattern - Pattern to use for JSON syntactic whitespace (doesn't impact string literals) - Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` - - Returns - ------- - A generation schedule. A list of strings that represent the JSON - schema's structure and regular expression that define the structure of - the fields. - - References - ---------- - .. [0] JSON Schema. https://json-schema.org/ - - """ - - schema = json.loads(schema) - Validator.check_schema(schema) - - # Build reference resolver - schema = Resource(contents=schema, specification=DRAFT202012) - uri = schema.id() if schema.id() is not None else "" - registry = Registry().with_resource(uri=uri, resource=schema) - resolver = registry.resolver() - - content = schema.contents - return to_regex(resolver, content, whitespace_pattern) def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) -> str: @@ -119,412 +40,7 @@ def convert_json_schema_to_str(json_schema: Union[dict, str, Type[BaseModel]]) - return schema_str -def _get_num_items_pattern(min_items, max_items, whitespace_pattern): - # Helper function for arrays and objects - min_items = int(min_items or 0) - if max_items is None: - return rf"{{{max(min_items - 1, 0)},}}" - else: - max_items = int(max_items) - if max_items < 1: - return None - return rf"{{{max(min_items - 1, 0)},{max_items - 1}}}" - - -def validate_quantifiers( - min_bound: Optional[str], max_bound: Optional[str], start_offset: int = 0 -) -> Tuple[str, str]: - """ - Ensures that the bounds of a number are valid. Bounds are used as quantifiers in the regex. - - Parameters - ---------- - min_bound - The minimum value that the number can take. - max_bound - The maximum value that the number can take. - start_offset - Number of elements that are already present in the regex but still need to be counted. - ex: if the regex is already "(-)?(0|[1-9][0-9])", we will always have at least 1 digit, so the start_offset is 1. - - Returns - ------- - min_bound - The minimum value that the number can take. - max_bound - The maximum value that the number can take. - - Raises - ------ - ValueError - If the minimum bound is greater than the maximum bound. - - TypeError or ValueError - If the minimum bound is not an integer or None. - or - If the maximum bound is not an integer or None. - """ - min_bound = "" if min_bound is None else str(int(min_bound) - start_offset) - max_bound = "" if max_bound is None else str(int(max_bound) - start_offset) - if min_bound and max_bound: - if int(max_bound) < int(min_bound): - raise ValueError("max bound must be greater than or equal to min bound") - return min_bound, max_bound - - -def to_regex( - resolver: Resolver, instance: dict, whitespace_pattern: Optional[str] = None -): - """Translate a JSON Schema instance into a regex that validates the schema. - - Note - ---- - Many features of JSON schema are missing: - - Handle `additionalProperties` keyword - - Handle types defined as a list - - Handle constraints on numbers - - Handle special patterns: `date`, `uri`, etc. - - This does not support recursive definitions. - - Parameters - ---------- - resolver - An object that resolves references to other instances within a schema - instance - The instance to translate - whitespace_pattern - Pattern to use for JSON syntactic whitespace (doesn't impact string literals) - Example: allow only a single space or newline with `whitespace_pattern=r"[\n ]?"` - """ - - # set whitespace pattern - if whitespace_pattern is None: - whitespace_pattern = WHITESPACE - - if instance == {}: - # JSON Schema Spec: Empty object means unconstrained, any json type is legal - types = [ - {"type": "boolean"}, - {"type": "null"}, - {"type": "number"}, - {"type": "integer"}, - {"type": "string"}, - {"type": "array"}, - {"type": "object"}, - ] - regexes = [to_regex(resolver, t, whitespace_pattern) for t in types] - regexes = [rf"({r})" for r in regexes] - return rf"{'|'.join(regexes)}" - - elif "properties" in instance: - regex = "" - regex += r"\{" - properties = instance["properties"] - required_properties = instance.get("required", []) - is_required = [item in required_properties for item in properties] - # If at least one property is required, we include the one in the lastest position - # without any comma. - # For each property before it (optional or required), we add with a comma after the property. - # For each property after it (optional), we add with a comma before the property. - if any(is_required): - last_required_pos = max([i for i, value in enumerate(is_required) if value]) - for i, (name, value) in enumerate(properties.items()): - subregex = f'{whitespace_pattern}"{re.escape(name)}"{whitespace_pattern}:{whitespace_pattern}' - subregex += to_regex(resolver, value, whitespace_pattern) - if i < last_required_pos: - subregex = f"{subregex}{whitespace_pattern}," - elif i > last_required_pos: - subregex = f"{whitespace_pattern},{subregex}" - regex += subregex if is_required[i] else f"({subregex})?" - # If no property is required, we have to create a possible pattern for each property in which - # it's the last one necessarilly present. Then, we add the others as optional before and after - # following the same strategy as described above. - # The whole block is made optional to allow the case in which no property is returned. - else: - property_subregexes = [] - for i, (name, value) in enumerate(properties.items()): - subregex = f'{whitespace_pattern}"{name}"{whitespace_pattern}:{whitespace_pattern}' - subregex += to_regex(resolver, value, whitespace_pattern) - property_subregexes.append(subregex) - possible_patterns = [] - for i in range(len(property_subregexes)): - pattern = "" - for subregex in property_subregexes[:i]: - pattern += f"({subregex}{whitespace_pattern},)?" - pattern += property_subregexes[i] - for subregex in property_subregexes[i + 1 :]: - pattern += f"({whitespace_pattern},{subregex})?" - possible_patterns.append(pattern) - regex += f"({'|'.join(possible_patterns)})?" - - regex += f"{whitespace_pattern}" + r"\}" - - return regex - - # To validate against allOf, the given data must be valid against all of the - # given subschemas. - elif "allOf" in instance: - subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["allOf"] - ] - subregexes_str = [f"{subregex}" for subregex in subregexes] - return rf"({''.join(subregexes_str)})" - - # To validate against `anyOf`, the given data must be valid against - # any (one or more) of the given subschemas. - elif "anyOf" in instance: - subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["anyOf"] - ] - return rf"({'|'.join(subregexes)})" - - # To validate against oneOf, the given data must be valid against exactly - # one of the given subschemas. - elif "oneOf" in instance: - subregexes = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["oneOf"] - ] - - xor_patterns = [f"(?:{subregex})" for subregex in subregexes] - - return rf"({'|'.join(xor_patterns)})" - - # Create pattern for Tuples, per JSON Schema spec, `prefixItems` determines types at each idx - elif "prefixItems" in instance: - element_patterns = [ - to_regex(resolver, t, whitespace_pattern) for t in instance["prefixItems"] - ] - comma_split_pattern = rf"{whitespace_pattern},{whitespace_pattern}" - tuple_inner = comma_split_pattern.join(element_patterns) - return rf"\[{whitespace_pattern}{tuple_inner}{whitespace_pattern}\]" - - # The enum keyword is used to restrict a value to a fixed set of values. It - # must be an array with at least one element, where each element is unique. - elif "enum" in instance: - choices = [] - for choice in instance["enum"]: - if type(choice) in [int, float, bool, type(None), str]: - choices.append(re.escape(json.dumps(choice))) - else: - raise TypeError(f"Unsupported data type in enum: {type(choice)}") - return f"({'|'.join(choices)})" - - elif "const" in instance: - const = instance["const"] - if type(const) in [int, float, bool, type(None), str]: - const = re.escape(json.dumps(const)) - else: - raise TypeError(f"Unsupported data type in const: {type(const)}") - return const - - elif "$ref" in instance: - path = f"{instance['$ref']}" - instance = resolver.lookup(path).contents - return to_regex(resolver, instance, whitespace_pattern) - - # The type keyword may either be a string or an array: - # - If it's a string, it is the name of one of the basic types. - # - If it is an array, it must be an array of strings, where each string is - # the name of one of the basic types, and each element is unique. In this - # case, the JSON snippet is valid if it matches any of the given types. - elif "type" in instance: - instance_type = instance["type"] - if instance_type == "string": - if "maxLength" in instance or "minLength" in instance: - max_items = instance.get("maxLength", "") - min_items = instance.get("minLength", "") - try: - if int(max_items) < int(min_items): - raise ValueError( - "maxLength must be greater than or equal to minLength" - ) # FIXME this raises an error but is caught right away by the except (meant for int("") I assume) - except ValueError: - pass - return f'"{STRING_INNER}{{{min_items},{max_items}}}"' - elif "pattern" in instance: - pattern = instance["pattern"] - if pattern[0] == "^" and pattern[-1] == "$": - return rf'("{pattern[1:-1]}")' - else: - return rf'("{pattern}")' - elif "format" in instance: - format = instance["format"] - if format == "date-time": - return format_to_regex["date-time"] - elif format == "uuid": - return format_to_regex["uuid"] - elif format == "date": - return format_to_regex["date"] - elif format == "time": - return format_to_regex["time"] - else: - raise NotImplementedError( - f"Format {format} is not supported by Outlines" - ) - else: - return type_to_regex["string"] - - elif instance_type == "number": - bounds = { - "minDigitsInteger", - "maxDigitsInteger", - "minDigitsFraction", - "maxDigitsFraction", - "minDigitsExponent", - "maxDigitsExponent", - } - if bounds.intersection(set(instance.keys())): - min_digits_integer, max_digits_integer = validate_quantifiers( - instance.get("minDigitsInteger"), - instance.get("maxDigitsInteger"), - start_offset=1, - ) - min_digits_fraction, max_digits_fraction = validate_quantifiers( - instance.get("minDigitsFraction"), instance.get("maxDigitsFraction") - ) - min_digits_exponent, max_digits_exponent = validate_quantifiers( - instance.get("minDigitsExponent"), instance.get("maxDigitsExponent") - ) - integers_quantifier = ( - f"{{{min_digits_integer},{max_digits_integer}}}" - if min_digits_integer or max_digits_integer - else "*" - ) - fraction_quantifier = ( - f"{{{min_digits_fraction},{max_digits_fraction}}}" - if min_digits_fraction or max_digits_fraction - else "+" - ) - exponent_quantifier = ( - f"{{{min_digits_exponent},{max_digits_exponent}}}" - if min_digits_exponent or max_digits_exponent - else "+" - ) - return rf"((-)?(0|[1-9][0-9]{integers_quantifier}))(\.[0-9]{fraction_quantifier})?([eE][+-][0-9]{exponent_quantifier})?" - return type_to_regex["number"] - - elif instance_type == "integer": - if "minDigits" in instance or "maxDigits" in instance: - min_digits, max_digits = validate_quantifiers( - instance.get("minDigits"), instance.get("maxDigits"), start_offset=1 - ) - return rf"(-)?(0|[1-9][0-9]{{{min_digits},{max_digits}}})" - return type_to_regex["integer"] - - elif instance_type == "array": - num_repeats = _get_num_items_pattern( - instance.get("minItems"), instance.get("maxItems"), whitespace_pattern - ) - if num_repeats is None: - return rf"\[{whitespace_pattern}\]" - - allow_empty = "?" if int(instance.get("minItems", 0)) == 0 else "" - - if "items" in instance: - items_regex = to_regex(resolver, instance["items"], whitespace_pattern) - return rf"\[{whitespace_pattern}(({items_regex})(,{whitespace_pattern}({items_regex})){num_repeats}){allow_empty}{whitespace_pattern}\]" - else: - # Here we need to make the choice to exclude generating list of objects - # if the specification of the object is not given, even though a JSON - # object that contains an object here would be valid under the specification. - legal_types = [ - {"type": "boolean"}, - {"type": "null"}, - {"type": "number"}, - {"type": "integer"}, - {"type": "string"}, - ] - depth = instance.get("depth", 2) - if depth > 0: - legal_types.append({"type": "object", "depth": depth - 1}) - legal_types.append({"type": "array", "depth": depth - 1}) - - regexes = [ - to_regex(resolver, t, whitespace_pattern) for t in legal_types - ] - return rf"\[{whitespace_pattern}({'|'.join(regexes)})(,{whitespace_pattern}({'|'.join(regexes)})){num_repeats}{allow_empty}{whitespace_pattern}\]" - - elif instance_type == "object": - # pattern for json object with values defined by instance["additionalProperties"] - # enforces value type constraints recursively, "minProperties", and "maxProperties" - # doesn't enforce "required", "dependencies", "propertyNames" "any/all/on Of" - num_repeats = _get_num_items_pattern( - instance.get("minProperties"), - instance.get("maxProperties"), - whitespace_pattern, - ) - if num_repeats is None: - return rf"\{{{whitespace_pattern}\}}" - - allow_empty = "?" if int(instance.get("minProperties", 0)) == 0 else "" - - additional_properties = instance.get("additionalProperties") - - if additional_properties is None or additional_properties is True: - # JSON Schema behavior: If the additionalProperties of an object is - # unset or True, it is unconstrained object. - # We handle this by setting additionalProperties to anyOf: {all types} - - legal_types = [ - {"type": "string"}, - {"type": "number"}, - {"type": "boolean"}, - {"type": "null"}, - ] - - # We set the object depth to 2 to keep the expression finite, but the "depth" - # key is not a true component of the JSON Schema specification. - depth = instance.get("depth", 2) - if depth > 0: - legal_types.append({"type": "object", "depth": depth - 1}) - legal_types.append({"type": "array", "depth": depth - 1}) - additional_properties = {"anyOf": legal_types} - - value_pattern = to_regex( - resolver, additional_properties, whitespace_pattern - ) - key_value_pattern = ( - f"{STRING}{whitespace_pattern}:{whitespace_pattern}{value_pattern}" - ) - key_value_successor_pattern = ( - f"{whitespace_pattern},{whitespace_pattern}{key_value_pattern}" - ) - multiple_key_value_pattern = f"({key_value_pattern}({key_value_successor_pattern}){num_repeats}){allow_empty}" - - return ( - r"\{" - + whitespace_pattern - + multiple_key_value_pattern - + whitespace_pattern - + r"\}" - ) - - elif instance_type == "boolean": - return type_to_regex["boolean"] - - elif instance_type == "null": - return type_to_regex["null"] - - elif isinstance(instance_type, list): - # Here we need to make the choice to exclude generating an object - # if the specification of the object is not give, even though a JSON - # object that contains an object here would be valid under the specification. - regexes = [ - to_regex(resolver, {"type": t}, whitespace_pattern) - for t in instance_type - if t != "object" - ] - return rf"({'|'.join(regexes)})" - - raise NotImplementedError( - f"""Could not translate the instance {instance} to a - regular expression. Make sure it is valid to the JSON Schema specification. If - it is, please open an issue on the Outlines repository""" - ) - - -def get_schema_from_signature(fn: Callable) -> str: +def get_schema_from_signature(fn: Callable) -> dict: """Turn a function signature into a JSON schema. Every JSON object valid to the output JSON Schema can be passed @@ -550,3 +66,18 @@ def get_schema_from_signature(fn: Callable) -> str: model = create_model(fn_name, **arguments) return model.model_json_schema() + + +def get_schema_from_enum(myenum: type[Enum]) -> dict: + if len(myenum) == 0: + raise ValueError( + f"Your enum class {myenum.__name__} has 0 members. If you are working with an enum of functions, do not forget to register them as callable (using `partial` for instance)" + ) + choices = [ + get_schema_from_signature(elt.value.func) + if callable(elt.value) + else {"const": elt.value} + for elt in myenum + ] + schema = {"title": myenum.__name__, "oneOf": choices} + return schema diff --git a/outlines/generate/choice.py b/outlines/generate/choice.py index 595513d52..afb998f52 100644 --- a/outlines/generate/choice.py +++ b/outlines/generate/choice.py @@ -1,7 +1,12 @@ import json as pyjson +import re +from enum import Enum from functools import singledispatch -from typing import Callable, List +from typing import Callable, List, Union +from outlines_core.fsm.json_schema import build_regex_from_schema + +from outlines.fsm.json_schema import get_schema_from_enum from outlines.generate.api import SequenceGeneratorAdapter from outlines.models import OpenAI from outlines.samplers import Sampler, multinomial @@ -12,12 +17,19 @@ @singledispatch def choice( - model, choices: List[str], sampler: Sampler = multinomial() + model, choices: Union[List[str], type[Enum]], sampler: Sampler = multinomial() ) -> SequenceGeneratorAdapter: - regex_str = r"(" + r"|".join(choices) + r")" + if isinstance(choices, type(Enum)): + regex_str = build_regex_from_schema(pyjson.dumps(get_schema_from_enum(choices))) + else: + choices = [re.escape(choice) for choice in choices] # type: ignore + regex_str = r"(" + r"|".join(choices) + r")" generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: x + if isinstance(choices, type(Enum)): + generator.format_sequence = lambda x: pyjson.loads(x) + else: + generator.format_sequence = lambda x: x return generator diff --git a/outlines/generate/json.py b/outlines/generate/json.py index f75878d29..d098d920d 100644 --- a/outlines/generate/json.py +++ b/outlines/generate/json.py @@ -1,10 +1,12 @@ import json as pyjson +from enum import Enum from functools import singledispatch from typing import Callable, Optional, Union +from outlines_core.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel -from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_signature +from outlines.fsm.json_schema import get_schema_from_enum, get_schema_from_signature from outlines.generate.api import SequenceGeneratorAdapter from outlines.models import OpenAI from outlines.samplers import Sampler, multinomial @@ -48,6 +50,11 @@ def json( regex_str = build_regex_from_schema(schema, whitespace_pattern) generator = regex(model, regex_str, sampler) generator.format_sequence = lambda x: schema_object.parse_raw(x) + elif isinstance(schema_object, type(Enum)): + schema = pyjson.dumps(get_schema_from_enum(schema_object)) + regex_str = build_regex_from_schema(schema, whitespace_pattern) + generator = regex(model, regex_str, sampler) + generator.format_sequence = lambda x: pyjson.loads(x) elif callable(schema_object): schema = pyjson.dumps(get_schema_from_signature(schema_object)) regex_str = build_regex_from_schema(schema, whitespace_pattern) diff --git a/outlines/models/__init__.py b/outlines/models/__init__.py index d28fcb2d7..fe6f861ac 100644 --- a/outlines/models/__init__.py +++ b/outlines/models/__init__.py @@ -16,4 +16,4 @@ from .transformers_vision import TransformersVision, transformers_vision from .vllm import VLLM, vllm -LogitsGenerator = Union[Transformers, LlamaCpp, ExLlamaV2Model, MLXLM, VLLM] +LogitsGenerator = Union[Transformers, LlamaCpp, OpenAI, ExLlamaV2Model, MLXLM, VLLM] diff --git a/outlines/models/mlxlm.py b/outlines/models/mlxlm.py index 6e63ef5b6..d8b7e032c 100644 --- a/outlines/models/mlxlm.py +++ b/outlines/models/mlxlm.py @@ -167,12 +167,7 @@ def sample(logits: "mx.array") -> Tuple["mx.array", float]: prob = softmax_logits[0, token] return token, prob - kv_heads = ( - [self.model.n_kv_heads] * len(self.model.layers) - if isinstance(self.model.n_kv_heads, int) - else self.model.n_kv_heads - ) - cache = [mlx_lm.models.base.KVCache(self.model.head_dim, n) for n in kv_heads] + cache = mlx_lm.models.cache.make_prompt_cache(self.model) # kv cache contains processed input IDs, we pass the unprocessed inputs and cache to model() unprocessed_input_ids = prompt diff --git a/outlines/models/transformers.py b/outlines/models/transformers.py index 7ecc9013f..444492500 100644 --- a/outlines/models/transformers.py +++ b/outlines/models/transformers.py @@ -2,8 +2,6 @@ import inspect from typing import TYPE_CHECKING, Iterator, List, Optional, Tuple, Union -from datasets.fingerprint import Hasher - from outlines.generate.api import GenerationParameters, SamplingParameters from outlines.models.tokenizer import Tokenizer @@ -116,6 +114,8 @@ def __eq__(self, other): return NotImplemented def __hash__(self): + from datasets.fingerprint import Hasher + return hash(Hasher.hash(self.tokenizer)) def __getstate__(self): diff --git a/outlines/models/vllm.py b/outlines/models/vllm.py index d1f97bde2..778c27c6f 100644 --- a/outlines/models/vllm.py +++ b/outlines/models/vllm.py @@ -1,11 +1,10 @@ import dataclasses from typing import TYPE_CHECKING, List, Optional, Union -from transformers import SPIECE_UNDERLINE, PreTrainedTokenizerBase - from outlines.generate.api import GenerationParameters, SamplingParameters if TYPE_CHECKING: + from transformers import PreTrainedTokenizerBase from vllm import LLM from vllm.sampling_params import SamplingParams @@ -188,7 +187,7 @@ def vllm(model_name: str, **vllm_model_params): return VLLM(model) -def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBase: +def adapt_tokenizer(tokenizer: "PreTrainedTokenizerBase") -> "PreTrainedTokenizerBase": """Adapt a tokenizer to use to compile the FSM. The API of Outlines tokenizers is slightly different to that of `transformers`. In @@ -205,6 +204,8 @@ def adapt_tokenizer(tokenizer: PreTrainedTokenizerBase) -> PreTrainedTokenizerBa PreTrainedTokenizerBase The adapted tokenizer. """ + from transformers import SPIECE_UNDERLINE + tokenizer.vocabulary = tokenizer.get_vocab() tokenizer.special_tokens = set(tokenizer.all_special_tokens) diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py index 9a52abecd..44b55af2e 100644 --- a/outlines/processors/base_logits_processor.py +++ b/outlines/processors/base_logits_processor.py @@ -20,6 +20,16 @@ def is_mlx_array_type(array_type): return issubclass(array_type, mx.array) +def is_jax_array_type(array_type): + try: + import jaxlib + except ImportError: + return False + return issubclass(array_type, jaxlib.xla_extension.ArrayImpl) or isinstance( + array_type, jaxlib.xla_extension.ArrayImpl + ) + + class OutlinesLogitsProcessor(Protocol): """ Base class for logits processors which normalizes types of logits: @@ -97,9 +107,18 @@ def _to_torch(tensor_like: Array) -> torch.Tensor: return torch.tensor(tensor_like) elif is_mlx_array_type(type(tensor_like)): - # mlx -> torch -> mlx conversion docs: - # https://ml-explore.github.io/mlx/build/html/usage/numpy.html - return torch.from_dlpack(tensor_like) + import mlx.core as mx + + # https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch + return torch.from_dlpack( + np.array(tensor_like.astype(mx.float32), copy=False) + ) + + elif is_jax_array_type(type(tensor_like)): + import jax + + torch_tensor = torch.from_dlpack(jax.dlpack.to_dlpack(tensor_like)) + return torch_tensor else: raise TypeError( @@ -129,6 +148,11 @@ def _from_torch(tensor: torch.Tensor, target_type: Type) -> Array: # numpy doesn't support bfloat16, mlx doesn't support direct conversion from torch return mx.array(tensor.float().numpy()) + elif is_jax_array_type(target_type): + import jax + + return jax.dlpack.from_dlpack(tensor) + else: raise TypeError( f"Failed to convert torch tensors to target_type `{target_type}`" diff --git a/outlines/processors/structured.py b/outlines/processors/structured.py index 50ae6e3ee..6a0d4236e 100644 --- a/outlines/processors/structured.py +++ b/outlines/processors/structured.py @@ -27,10 +27,11 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type, Union import torch +from outlines_core.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel from outlines.fsm.guide import CFGGuide, Guide, RegexGuide -from outlines.fsm.json_schema import build_regex_from_schema, convert_json_schema_to_str +from outlines.fsm.json_schema import convert_json_schema_to_str from .base_logits_processor import OutlinesLogitsProcessor diff --git a/pyproject.toml b/pyproject.toml index fa7005afd..b83275f89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "jinja2", "lark", "nest_asyncio", - "numpy<2.0.0", + "numpy", "cloudpickle", "diskcache", "pydantic>=2.0", @@ -36,16 +36,21 @@ dependencies = [ "jsonschema", "requests", "tqdm", - "datasets", "typing_extensions", "pycountry", "airportsdata", "torch", - "outlines_core==0.1.14", + "outlines_core==0.1.25", ] dynamic = ["version"] [project.optional-dependencies] +vllm = ["vllm", "transformers", "numpy<2"] +transformers = ["transformers", "accelerate", "datasets", "numpy<2"] +mlxlm = ["mlx-lm", "datasets"] +openai = ["openai"] +llamacpp = ["llama-cpp-python", "transformers", "datasets", "numpy<2"] +exllamav2 = ["exllamav2"] test = [ "pre-commit", "pytest", @@ -58,13 +63,15 @@ test = [ "beartype<0.16.0", "responses", "llama-cpp-python", - "mlx-lm; platform_machine == 'arm64' and sys_platform == 'darwin'", + "mlx-lm>=0.19.2; platform_machine == 'arm64' and sys_platform == 'darwin'", "huggingface_hub", "openai>=1.0.0", + "datasets", "vllm; sys_platform != 'darwin'", "transformers", "pillow", "exllamav2", + "jax" ] serve = [ "vllm>=0.3.0", @@ -109,6 +116,9 @@ enable_incomplete_feature = ["Unpack"] [[tool.mypy.overrides]] module = [ "exllamav2.*", + "jax", + "jaxlib", + "jax.numpy", "jinja2", "jsonschema.*", "openai.*", diff --git a/tests/fsm/test_guide.py b/tests/fsm/test_guide.py index 510faf4b0..bf25c43c4 100644 --- a/tests/fsm/test_guide.py +++ b/tests/fsm/test_guide.py @@ -59,7 +59,7 @@ def convert_token_to_string(self, token): tokenizer = MockTokenizer() fsm = RegexGuide.from_regex(regex_str, tokenizer) - assert fsm.states_to_token_maps == {0: {1: 1}} + assert fsm.states_to_token_maps.get_transitions() == {0: {1: 1}} instruction = fsm.get_next_instruction(0) assert isinstance(instruction, Generate) @@ -70,9 +70,6 @@ def convert_token_to_string(self, token): assert fsm.is_final_state(0) is False - for state in fsm.final_states: - assert fsm.is_final_state(state) is True - def test_regex_multi_byte_llama_like(): class MockTokenizer: @@ -100,7 +97,7 @@ def convert_token_to_string(self, token): tokenizer = MockTokenizer() fsm = RegexGuide.from_regex(regex_str, tokenizer) - assert fsm.states_to_token_maps == { + assert fsm.states_to_token_maps.get_transitions() == { 0: {5: 1, 4: 2}, 1: {6: 3}, 3: {7: 4}, @@ -116,9 +113,6 @@ def convert_token_to_string(self, token): assert fsm.is_final_state(0) is False - for state in fsm.final_states: - assert fsm.is_final_state(state) is True - def test_regex_multi_byte_gpt2_like(): class MockTokenizer: @@ -147,7 +141,7 @@ def convert_token_to_string(self, token): tokenizer = MockTokenizer() fsm = RegexGuide.from_regex(regex_str, tokenizer) - assert fsm.states_to_token_maps == { + assert fsm.states_to_token_maps.get_transitions() == { 0: {5: 1, 10: 2}, 1: {8: 5, 4: 3}, 2: {11: 3}, @@ -163,9 +157,6 @@ def convert_token_to_string(self, token): assert fsm.is_final_state(0) is False - for state in fsm.final_states: - assert fsm.is_final_state(state) is True - def test_regex_final_state(): """Make sure that the FSM stays in the final state as we keep generating""" diff --git a/tests/fsm/test_json_schema.py b/tests/fsm/test_json_schema.py index 7565ff642..23864e029 100644 --- a/tests/fsm/test_json_schema.py +++ b/tests/fsm/test_json_schema.py @@ -1,27 +1,14 @@ import json -import re -from typing import List, Literal, Union +from contextlib import nullcontext +from enum import Enum +from functools import partial +from typing import List -import interegular import pytest -from pydantic import BaseModel, Field, constr +from outlines_core.fsm.json_schema import build_regex_from_schema +from pydantic import BaseModel, constr -from outlines.fsm.json_schema import ( - BOOLEAN, - DATE, - DATE_TIME, - INTEGER, - NULL, - NUMBER, - STRING, - STRING_INNER, - TIME, - UUID, - WHITESPACE, - build_regex_from_schema, - get_schema_from_signature, - to_regex, -) +from outlines.fsm.json_schema import get_schema_from_enum, get_schema_from_signature def test_function_basic(): @@ -54,988 +41,38 @@ class User(BaseModel): is_true: bool schema = json.dumps(User.model_json_schema()) - schedule = build_regex_from_schema(schema) - assert isinstance(schedule, str) - - -@pytest.mark.parametrize( - "pattern,does_match", - [ - ({"integer": "0"}, True), - ({"integer": "1"}, True), - ({"integer": "-1"}, True), - ({"integer": "01"}, False), - ({"integer": "1.3"}, False), - ({"integer": "t"}, False), - ], -) -def test_match_integer(pattern, does_match): - step = {"title": "Foo", "type": "integer"} - regex = to_regex(None, step) - assert regex == INTEGER - - value = pattern["integer"] - match = re.fullmatch(regex, value) - if does_match: - assert match[0] == value - assert match.span() == (0, len(value)) - else: - assert match is None + regex_str = build_regex_from_schema(schema) + assert isinstance(regex_str, str) -@pytest.mark.parametrize( - "pattern,does_match", - [ - ({"number": "1"}, True), - ({"number": "0"}, True), - ({"number": "01"}, False), - ({"number": ".3"}, False), - ({"number": "1.3"}, True), - ({"number": "-1.3"}, True), - ({"number": "1.3e9"}, False), - ({"number": "1.3e+9"}, True), - ], -) -def test_match_number(pattern, does_match): - step = {"title": "Foo", "type": "number"} - regex = to_regex(None, step) - assert regex == NUMBER +def add(a: float, b: float) -> float: + return a + b - value = pattern["number"] - match = re.fullmatch(regex, value) - if does_match: - assert match[0] == value - assert match.span() == (0, len(value)) - else: - assert match is None +class MyEnum(Enum): + add = partial(add) + a = "a" + b = 2 -@pytest.mark.parametrize( - "schema,regex,examples", - [ - # String - ( - {"title": "Foo", "type": "string"}, - STRING, - [ - ("unquotedstring", False), - ('"(parenthesized_string)"', True), - ('"malformed) parenthesis (((() string"', True), - ('"quoted_string"', True), - (r'"escape_\character"', False), - (r'"double_\\escape"', True), - (r'"\n"', False), - (r'"\\n"', True), - (r'"unescaped " quote"', False), - (r'"escaped \" quote"', True), - ], - ), - # String with maximum length - ( - {"title": "Foo", "type": "string", "maxLength": 3}, - f'"{STRING_INNER}{{,3}}"', - [('"ab"', True), ('"a""', False), ('"abcd"', False)], - ), - # String with minimum length - ( - {"title": "Foo", "type": "string", "minLength": 3}, - f'"{STRING_INNER}{{3,}}"', - [('"ab"', False), ('"abcd"', True), ('"abc""', False)], - ), - # String with both minimum and maximum length - ( - {"title": "Foo", "type": "string", "minLength": 3, "maxLength": 5}, - f'"{STRING_INNER}{{3,5}}"', - [('"ab"', False), ('"abcd"', True), ('"abcdef""', False)], - ), - # String defined by a regular expression - ( - {"title": "Foo", "type": "string", "pattern": r"^[a-z]$"}, - r'("[a-z]")', - [('"a"', True), ('"1"', False)], - ), - # Boolean - ( - {"title": "Foo", "type": "boolean"}, - BOOLEAN, - [ - ("true", True), - ("false", True), - ("null", False), - ("0", False), - ], - ), - # Null - ( - {"title": "Foo", "type": "null"}, - NULL, - [ - ("null", True), - ("true", False), - ("0", False), - ], - ), - # Const string - ( - {"title": "Foo", "const": "Marc", "type": "string"}, - '"Marc"', - [('"Marc"', True), ('"Jean"', False), ('"John"', False)], - ), - # Make sure strings are escaped with regex escaping - ( - {"title": "Foo", "const": ".*", "type": "string"}, - r'"\.\*"', - [('".*"', True), (r'"\s*"', False), (r'"\.\*"', False)], - ), - # Make sure strings are escaped with JSON escaping - ( - {"title": "Foo", "const": '"', "type": "string"}, - r'"\\""', - [('"\\""', True), ('"""', False)], - ), - # Const integer - ( - {"title": "Foo", "const": 0, "type": "integer"}, - "0", - [("0", True), ("1", False), ("a", False)], - ), - # Const float - ( - {"title": "Foo", "const": 0.2, "type": "float"}, - r"0\.2", - [("0.2", True), ("032", False)], - ), - # Const boolean - ( - {"title": "Foo", "const": True, "type": "boolean"}, - "true", - [("true", True), ("True", False)], - ), - # Const null - ( - {"title": "Foo", "const": None, "type": "null"}, - "null", - [("null", True), ("None", False), ("", False)], - ), - # Enum string - ( - {"title": "Foo", "enum": ["Marc", "Jean"], "type": "string"}, - '("Marc"|"Jean")', - [('"Marc"', True), ('"Jean"', True), ('"John"', False)], - ), - # Make sure strings are escaped with regex and JSON escaping - ( - {"title": "Foo", "enum": [".*", r"\s*"], "type": "string"}, - r'("\.\*"|"\\\\s\*")', - [('".*"', True), (r'"\\s*"', True), (r'"\.\*"', False)], - ), - # Enum integer - ( - {"title": "Foo", "enum": [0, 1], "type": "integer"}, - "(0|1)", - [("0", True), ("1", True), ("a", False)], - ), - # Enum mix of types - ( - {"title": "Foo", "enum": [6, 5.3, "potato", True, None]}, - r'(6|5\.3|"potato"|true|null)', - [ - ("6", True), - ("5.3", True), - ('"potato"', True), - ("true", True), - ("null", True), - ("523", False), - ("True", False), - ("None", False), - ], - ), - # integer - ( - { - "title": "Foo", - "type": "object", - "properties": {"count": {"title": "Count", "type": "integer"}}, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]*)[ ]?\\}', - [('{ "count": 100 }', True)], - ), - # integer with minimum digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": {"title": "Count", "type": "integer", "minDigits": 3} - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,})[ ]?\\}', - [('{ "count": 10 }', False), ('{ "count": 100 }', True)], - ), - # integer with maximum digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": {"title": "Count", "type": "integer", "maxDigits": 3} - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{,2})[ ]?\\}', - [('{ "count": 100 }', True), ('{ "count": 1000 }', False)], - ), - # integer with minimum and maximum digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": { - "title": "Count", - "type": "integer", - "minDigits": 3, - "maxDigits": 5, - } - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?(-)?(0|[1-9][0-9]{2,4})[ ]?\\}', - [ - ('{ "count": 10 }', False), - ('{ "count": 100 }', True), - ('{ "count": 10000 }', True), - ('{ "count": 100000 }', False), - ], - ), - # number - ( - { - "title": "Foo", - "type": "object", - "properties": {"count": {"title": "Count", "type": "number"}}, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\\}', - [('{ "count": 100 }', True), ('{ "count": 100.5 }', True)], - ), - # number with min and max integer digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": { - "title": "Count", - "type": "number", - "minDigitsInteger": 3, - "maxDigitsInteger": 5, - } - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]{2,4}))(\\.[0-9]+)?([eE][+-][0-9]+)?[ ]?\\}', - [ - ('{ "count": 10.005 }', False), - ('{ "count": 100.005 }', True), - ('{ "count": 10000.005 }', True), - ('{ "count": 100000.005 }', False), - ], - ), - # number with min and max fraction digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": { - "title": "Count", - "type": "number", - "minDigitsFraction": 3, - "maxDigitsFraction": 5, - } - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]{3,5})?([eE][+-][0-9]+)?[ ]?\\}', - [ - ('{ "count": 1.05 }', False), - ('{ "count": 1.005 }', True), - ('{ "count": 1.00005 }', True), - ('{ "count": 1.000005 }', False), - ], - ), - # number with min and max exponent digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": { - "title": "Count", - "type": "number", - "minDigitsExponent": 3, - "maxDigitsExponent": 5, - } - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]*))(\\.[0-9]+)?([eE][+-][0-9]{3,5})?[ ]?\\}', - [ - ('{ "count": 1.05e1 }', False), - ('{ "count": 1.05e+001 }', True), - ('{ "count": 1.05e-00001 }', True), - ('{ "count": 1.05e0000001 }', False), - ], - ), - # number with min and max integer, fraction and exponent digits - ( - { - "title": "Foo", - "type": "object", - "properties": { - "count": { - "title": "Count", - "type": "number", - "minDigitsInteger": 3, - "maxDigitsInteger": 5, - "minDigitsFraction": 3, - "maxDigitsFraction": 5, - "minDigitsExponent": 3, - "maxDigitsExponent": 5, - } - }, - "required": ["count"], - }, - '\\{[ ]?"count"[ ]?:[ ]?((-)?(0|[1-9][0-9]{2,4}))(\\.[0-9]{3,5})?([eE][+-][0-9]{3,5})?[ ]?\\}', - [ - ('{ "count": 1.05e1 }', False), - ('{ "count": 100.005e+001 }', True), - ('{ "count": 10000.00005e-00001 }', True), - ('{ "count": 100000.000005e0000001 }', False), - ], - ), - # array - ( - {"title": "Foo", "type": "array", "items": {"type": "number"}}, - rf"\[{WHITESPACE}(({NUMBER})(,{WHITESPACE}({NUMBER})){{0,}})?{WHITESPACE}\]", - [("[1e+9,1.3]", True), ("[]", True), ("[1", False)], - ), - # array with a set length of 1 - ( - { - "title": "Foo", - "type": "array", - "items": {"type": "integer"}, - "minItems": 1, - "maxItems": 1, - }, - rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{0,0}}){WHITESPACE}\]", - [("[1]", True), ("[1,2]", False), ('["a"]', False), ("[]", False)], - ), - # array with a set length greather than 1 - ( - { - "title": "Foo", - "type": "array", - "items": {"type": "integer"}, - "minItems": 3, - "maxItems": 3, - }, - rf"\[{WHITESPACE}(({INTEGER})(,{WHITESPACE}({INTEGER})){{2,2}}){WHITESPACE}\]", - [("[1]", False), ("[]", False), ("[1,2,3]", True), ("[1,2,3,4]", False)], - ), - # array with length 0 - ( - { - "title": "Foo", - "type": "array", - "items": {"type": "integer"}, - "minItems": 0, - "maxItems": 0, - }, - rf"\[{WHITESPACE}\]", - [("[1]", False), ("[]", True), ("[1,2,3]", False), ("[1,2,3,4]", False)], - ), - # object - ( - { - "title": "TestSchema", - "type": "object", - "properties": { - "test_dict": { - "title": "Test Dict", - "additionalProperties": {"type": "string"}, - "type": "object", - } - }, - "required": ["test_dict"], - }, - rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{STRING}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{STRING}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", - [ - ("""{ "test_dict":{"foo":"bar","baz": "bif"}}""", True), - ("""{ "test_dict":{"foo":"bar" }}""", True), - ("""{ "test_dict":{}}""", True), - ("""{ "WRONG_KEY":{}}""", False), - ("""{ "test_dict":{"wrong_type" 1}}""", False), - ], - ), - # object containing object - ( - { - "title": "TestSchema", - "type": "object", - "properties": { - "test_dict": { - "title": "Test Dict", - "additionalProperties": { - "additionalProperties": {"type": "integer"}, - "type": "object", - }, - "type": "object", - } - }, - "required": ["test_dict"], - }, - rf"""\{{{WHITESPACE}"test_dict"{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}\{{{WHITESPACE}({STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}({WHITESPACE},{WHITESPACE}{STRING}{WHITESPACE}:{WHITESPACE}{INTEGER}){{0,}})?{WHITESPACE}\}}){{0,}})?{WHITESPACE}\}}{WHITESPACE}\}}""", - [ - ( - """{"test_dict": {"foo": {"bar": 123, "apple": 99}, "baz": {"bif": 456}}}""", - True, - ), - ( - """{"test_dict": {"anykey": {"anykey": 123}, "anykey2": {"bif": 456}}}""", - True, - ), - ("""{"test_dict": {}}""", True), - ("""{"test_dict": {"dict of empty dicts are ok": {} }}""", True), - ( - """{"test_dict": {"anykey": {"ONLY Dict[Dict]": 123}, "No Dict[int]" 1: }}""", - False, - ), - ], - ), - # oneOf - ( - { - "title": "Foo", - "oneOf": [{"type": "string"}, {"type": "number"}, {"type": "boolean"}], - }, - rf'((?:"{STRING_INNER}*")|(?:{NUMBER})|(?:{BOOLEAN}))', - [ - ("12.3", True), - ("true", True), - ('"a"', True), - ("null", False), - ("", False), - ("12true", False), - ('1.3"a"', False), - ('12.3true"a"', False), - ], - ), - # anyOf - ( - { - "title": "Foo", - "anyOf": [{"type": "string"}, {"type": "integer"}], - }, - rf"({STRING}|{INTEGER})", - [("12", True), ('"a"', True), ('1"a"', False)], - ), - # allOf - ( - { - "title": "Foo", - "allOf": [{"type": "string"}, {"type": "integer"}], - }, - rf"({STRING}{INTEGER})", - [('"a"1', True), ('"a"', False), ('"1"', False)], - ), - # Tuple / prefixItems - ( - { - "title": "Foo", - "prefixItems": [{"type": "string"}, {"type": "integer"}], - }, - rf"\[{WHITESPACE}{STRING}{WHITESPACE},{WHITESPACE}{INTEGER}{WHITESPACE}\]", - [('["a", 1]', True), ('["a", 1, 1]', False), ("[]", False)], - ), - # Nested schema - ( - { - "title": "Bar", - "type": "object", - "properties": { - "fuzz": { - "title": "Foo", - "type": "object", - "properties": {"spam": {"title": "Spam", "type": "integer"}}, - "required": ["spam"], - } - }, - "required": ["fuzz"], - }, - f'\\{{[ ]?"fuzz"[ ]?:[ ]?\\{{[ ]?"spam"[ ]?:[ ]?{INTEGER}[ ]?\\}}[ ]?\\}}', - [('{ "fuzz": { "spam": 100 }}', True)], - ), - # Schema with a reference - ( - { - "title": "User", - "type": "object", - "properties": { - "user_id": {"title": "User Id", "type": "integer"}, - "name": {"title": "Name", "type": "string"}, - "a": {"$ref": "#/properties/name"}, - }, - "required": ["user_id", "name", "a"], - }, - f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"a"[ ]?:[ ]?{STRING}[ ]?\\}}', - [('{"user_id": 100, "name": "John", "a": "Marc"}', True)], - ), - ( - { - "title": "User", - "type": "object", - "$defs": {"name": {"title": "Name2", "type": "string"}}, - "properties": { - "user_id": {"title": "User Id", "type": "integer"}, - "name": {"title": "Name", "type": "string"}, - "name2": {"$ref": "#/$defs/name"}, - }, - "required": ["user_id", "name", "name2"], - }, - f'\\{{[ ]?"user_id"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"name2"[ ]?:[ ]?{STRING}[ ]?\\}}', - [('{"user_id": 100, "name": "John", "name2": "Marc"}', True)], - ), - ( - { - "$id": "customer", - "$schema": "https://json-schema.org/draft/2020-12/schema", - "title": "Customer", - "type": "object", - "properties": { - "name": {"type": "string"}, - "last_name": {"type": "string"}, - "address": {"$ref": "customer#/$defs/address"}, - }, - "required": [ - "name", - "first_name", - "last_name", - "address", - "shipping_address", - "billing_address", - ], - "$defs": { - "address": { - "title": "Address", - "$schema": "http://json-schema.org/draft-07/schema#", - "type": "object", - "properties": { - "city": {"type": "string"}, - }, - "required": ["street_address", "city", "state"], - "definitions": { - "state": { - "type": "object", - "title": "State", - "properties": {"name": {"type": "string"}}, - "required": ["name"], - } - }, - } - }, - }, - f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"last_name"[ ]?:[ ]?{STRING}[ ]?,[ ]?"address"[ ]?:[ ]?\\{{[ ]?"city"[ ]?:[ ]?{STRING}[ ]?\\}}[ ]?\\}}', - [ - ( - '{"name": "John", "last_name": "Doe", "address": {"city": "Paris"}}', - True, - ) - ], - ), - # Optional properties - # Last required property in first position - ( - { - "properties": { - "name": {"type": "string"}, - "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - "weapon": {"anyOf": [{"type": "string"}, {"type": "null"}]}, - }, - "required": ["name"], - "title": "Character", - "type": "object", - }, - f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"weapon"[ ]?:[ ]?({STRING}|null))?[ ]?\\}}', - [ - ('{ "name" : "Player" }', True), - ('{ "name" : "Player", "weapon" : "sword" }', True), - ('{ "age" : 10, "weapon" : "sword" }', False), - ], - ), - # Last required property in middle position - ( - { - "properties": { - "name": {"type": "string"}, - "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - "weapon": {"type": "string"}, - "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - }, - "required": ["name", "weapon"], - "title": "Character", - "type": "object", - }, - f'\\{{[ ]?"name"[ ]?:[ ]?{STRING}[ ]?,([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', - [ - ('{ "name" : "Player" , "weapon" : "sword" }', True), - ( - '{ "name" : "Player", "age" : 10, "weapon" : "sword" , "strength" : 10 }', - True, - ), - ('{ "weapon" : "sword" }', False), - ], - ), - # Last required property in last position - ( - { - "properties": { - "name": {"anyOf": [{"type": "string"}, {"type": "null"}]}, - "age": {"type": "integer"}, - "armor": {"type": "string"}, - "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - "weapon": {"title": "Weapon", "type": "string"}, - }, - "required": ["age", "armor", "weapon"], - "title": "Character", - "type": "object", - }, - f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?{INTEGER}[ ]?,[ ]?"armor"[ ]?:[ ]?{STRING}[ ]?,([ ]?"strength"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"weapon"[ ]?:[ ]?{STRING}[ ]?\\}}', - [ - ( - '{ "name" : "Player", "age" : 10, "armor" : "plate", "strength" : 11, "weapon" : "sword" }', - True, - ), - ('{ "age" : 10, "armor" : "plate", "weapon" : "sword" }', True), - ( - '{ "name" : "Kahlhanbeh", "armor" : "plate", "weapon" : "sword" }', - False, - ), - ], - ), - # All properties are optional - ( - { - "properties": { - "name": {"anyOf": [{"type": "string"}, {"type": "null"}]}, - "age": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - "strength": {"anyOf": [{"type": "integer"}, {"type": "null"}]}, - }, - "title": "Character", - "type": "object", - }, - f'\\{{([ ]?"name"[ ]?:[ ]?({STRING}|null)([ ]?,[ ]?"age"[ ]?:[ ]?({INTEGER}|null))?([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?[ ]?"age"[ ]?:[ ]?({INTEGER}|null)([ ]?,[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?|([ ]?"name"[ ]?:[ ]?({STRING}|null)[ ]?,)?([ ]?"age"[ ]?:[ ]?({INTEGER}|null)[ ]?,)?[ ]?"strength"[ ]?:[ ]?({INTEGER}|null))?[ ]?\\}}', - [ - ('{ "name" : "Player" }', True), - ('{ "name" : "Player", "age" : 10, "strength" : 10 }', True), - ('{ "age" : 10, "strength" : 10 }', True), - ("{ }", True), - ], - ), - ], -) -def test_match(schema, regex, examples): - interegular.parse_pattern(regex) - schema = json.dumps(schema) - test_regex = build_regex_from_schema(schema) - assert test_regex == regex - for string, does_match in examples: - match = re.fullmatch(test_regex, string) - if does_match: - if match is None: - raise ValueError(f"Expected match for '{string}'") - assert match[0] == string - assert match.span() == (0, len(string)) - else: - assert match is None +# if you don't register your function as callable, you will get an empty enum +class EmptyEnum(Enum): + add = add @pytest.mark.parametrize( - "schema,regex,examples", + "enum,expectation", [ - # UUID - ( - {"title": "Foo", "type": "string", "format": "uuid"}, - UUID, - [ - ("123e4567-e89b-12d3-a456-426614174000", False), - ('"123e4567-e89b-12d3-a456-426614174000"', True), - ('"123e4567-e89b-12d3-a456-42661417400"', False), - ('"123e4567-e89b-12d3-a456-42661417400g"', False), - ('"123e4567-e89b-12d3-a456-42661417400-"', False), - ('""', False), - ], - ), - # DATE-TIME - ( - {"title": "Foo", "type": "string", "format": "date-time"}, - DATE_TIME, - [ - ("2018-11-13T20:20:39Z", False), - ('"2018-11-13T20:20:39Z"', True), - ('"2016-09-18T17:34:02.666Z"', True), - ('"2008-05-11T15:30:00Z"', True), - ('"2021-01-01T00:00:00"', True), - ('"2022-01-10 07:19:30"', False), # missing T - ('"2022-12-10T10-04-29"', False), # incorrect separator - ('"2023-01-01"', False), - ], - ), - # DATE - ( - {"title": "Foo", "type": "string", "format": "date"}, - DATE, - [ - ("2018-11-13", False), - ('"2018-11-13"', True), - ('"2016-09-18"', True), - ('"2008-05-11"', True), - ('"2015-13-01"', False), # incorrect month - ('"2022-01"', False), # missing day - ('"2022/12/01"', False), # incorrect separator" - ], - ), - # TIME - ( - {"title": "Foo", "type": "string", "format": "time"}, - TIME, - [ - ("20:20:39Z", False), - ('"20:20:39Z"', True), - ('"15:30:00Z"', True), - ('"25:30:00"', False), # incorrect hour - ('"15:30"', False), # missing seconds - ('"15:30:00.000"', False), # missing Z - ('"15-30-00"', False), # incorrect separator - ('"15:30:00+01:00"', False), # incorrect separator - ], - ), + (MyEnum, nullcontext()), + (EmptyEnum, pytest.raises(ValueError)), ], ) -def test_format(schema, regex, examples): - interegular.parse_pattern(regex) - schema = json.dumps(schema) - test_regex = build_regex_from_schema(schema) - assert test_regex == regex - - for string, does_match in examples: - match = re.fullmatch(test_regex, string) - if does_match: - assert match[0] == string - assert match.span() == (0, len(string)) - else: - assert match is None - - -@pytest.mark.parametrize( - "schema,examples", - [ - # NESTED UUID - ( - { - "title": "Foo", - "type": "object", - "properties": {"uuid": {"type": "string", "format": "uuid"}}, - }, - [ - ('{"uuid": "123e4567-e89b-12d3-a456-426614174000"}', True), - ('{"uuid":"123e4567-e89b-12d3-a456-42661417400"}', False), - ('{"uuid":"123e4567-e89b-12d3-a456-42661417400g"}', False), - ('{"uuid":"123e4567-e89b-12d3-a456-42661417400-"}', False), - ( - '{"uuid":123e4567-e89b-12d3-a456-426614174000}', - False, - ), # missing quotes for value - ('{"uuid":""}', False), - ], - ), - # NESTED DATE-TIME - ( - { - "title": "Foo", - "type": "object", - "properties": {"dateTime": {"type": "string", "format": "date-time"}}, - }, - [ - ('{"dateTime": "2018-11-13T20:20:39Z"}', True), - ('{"dateTime":"2016-09-18T17:34:02.666Z"}', True), - ('{"dateTime":"2008-05-11T15:30:00Z"}', True), - ('{"dateTime":"2021-01-01T00:00:00"}', True), - ('{"dateTime":"2022-01-10 07:19:30"}', False), # missing T - ('{"dateTime":"2022-12-10T10-04-29"}', False), # incorrect separator - ( - '{"dateTime":2018-11-13T20:20:39Z}', - False, - ), # missing quotes for value - ('{"dateTime":"2023-01-01"}', False), - ], - ), - # NESTED DATE - ( - { - "title": "Foo", - "type": "object", - "properties": {"date": {"type": "string", "format": "date"}}, - }, - [ - ('{"date": "2018-11-13"}', True), - ('{"date":"2016-09-18"}', True), - ('{"date":"2008-05-11"}', True), - ('{"date":"2015-13-01"}', False), # incorrect month - ('{"date":"2022-01"}', False), # missing day - ('{"date":"2022/12/01"}', False), # incorrect separator" - ('{"date":2018-11-13}', False), # missing quotes for value - ], - ), - # NESTED TIME - ( - { - "title": "Foo", - "type": "object", - "properties": {"time": {"type": "string", "format": "time"}}, - }, - [ - ('{"time": "20:20:39Z"}', True), - ('{"time":"15:30:00Z"}', True), - ('{"time":"25:30:00"}', False), # incorrect hour - ('{"time":"15:30"}', False), # missing seconds - ('{"time":"15:30:00.000"}', False), # missing Z - ('{"time":"15-30-00"}', False), # incorrect separator - ('{"time":"15:30:00+01:00"}', False), # incorrect separator - ('{"time":20:20:39Z}', False), # missing quotes for value - ], - ), - # Unconstrained Object - ( - { - "title": "Foo", - "type": "object", - }, - [ - ("{}", True), - ('{"a": 1, "b": null}', True), - ('{"a": {"z": {"g": 4}}, "b": null}', True), - ("1234", False), # not an object - ('["a", "a"]', False), # not an array - ], - ), - # Unconstrained Array - ( - { - "type": "array", - }, - [ - ("[1, {}, false]", True), - ("[{}]", True), - ('[{"a": {"z": "q"}, "b": null}]', True), - ('[{"a": [1, 2, true], "b": null}]', True), - ('[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2]]]', True), - # too deep, default unconstrained depth limit = 2 - ( - '[{"a": [1, 2, true], "b": {"a": "b"}}, 1, true, [1, [2, [3]]]]', - False, - ), - ('[{"a": {"z": {"g": 4}}, "b": null}]', False), - ("[[[[1]]]]", False), - # not an array - ("{}", False), - ('{"a": 1, "b": null}', False), - ('{"a": {"z": {"g": 4}}, "b": null}', False), - ("1234", False), # not an array - ('{"a": "a"}', False), # not an array - ], - ), - # No schema / unconstrained value - ( - {}, - [ - ('"aaabbuecuh"', True), # string - ("5.554", True), # number - ("true", True), # boolean - ("null", True), # null - ("5999", True), # integer - ('["a", "b"]', True), # array - ('{"key": {"k2": "value"}}', True), # nested object - ("this isnt valid json", False), - ], - ), - ], -) -def test_format_without_regex(schema, examples): - schema = json.dumps(schema) - test_regex = build_regex_from_schema(schema) - for string, does_match in examples: - match = re.fullmatch(test_regex, string) - if does_match: - assert match[0] == string - assert match.span() == (0, len(string)) - else: - assert match is None - - -@pytest.mark.parametrize("whitespace_pattern", [None, r"[\n ]*", "abc"]) -def test_json_schema_custom_whitespace_pattern(whitespace_pattern): - """assert whitespace_pattern setting respected""" - - class MockModel(BaseModel): - foo: int - bar: str - - schema = json.dumps(MockModel.model_json_schema()) - - # assert any ws pattern can be used - if whitespace_pattern == "abc": - build_regex_from_schema(schema, whitespace_pattern) - return - - pattern = build_regex_from_schema(schema, whitespace_pattern) - - mock_result_mult_ws = ( - """{ "foo" : 4, \n\n\n "bar": "baz baz baz bar"\n\n}""" - ) - mock_result_maybe_ws = """{"foo" : 4 ,"bar":"baz baz baz bar"}""" - - match_default_ws = re.fullmatch(pattern, mock_result_maybe_ws) - if whitespace_pattern is None: - assert match_default_ws - else: - assert re.fullmatch(pattern, mock_result_mult_ws) - - -def test_one_of_doesnt_produce_illegal_lookaround(): - """Reproduces failure in https://github.com/dottxt-ai/outlines/issues/823""" - - class Cat(BaseModel): - pet_type: Literal["cat"] - meows: int - - class Dog(BaseModel): - pet_type: Literal["dog"] - barks: float - - class Model(BaseModel): - pet: Union[Cat, Dog] = Field(..., discriminator="pet_type") - n: int - - json_schema = Model.schema_json() - - json_schema = Model.schema_json() - pattern = build_regex_from_schema(json_schema, whitespace_pattern=None) - - # check if the pattern uses lookarounds incompatible with interegular.Pattern.to_fsm() - interegular.parse_pattern(pattern).to_fsm() +def test_enum_schema(enum, expectation): + with expectation: + schema = get_schema_from_enum(enum) + regex_str = build_regex_from_schema(json.dumps(schema)) + assert isinstance(regex_str, str) + assert schema["title"] == enum.__name__ + assert len(schema["oneOf"]) == len(enum) + for elt in schema["oneOf"]: + assert type(elt) in [int, float, bool, type(None), str, dict] diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index 9c288c21e..f91bc8653 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -1,5 +1,6 @@ import contextlib import re +from enum import Enum import pytest @@ -127,6 +128,18 @@ def model_t5(tmp_path_factory): ) +class MyEnum(Enum): + foo = "foo" + bar = "bar" + baz = "baz" + + +ALL_SAMPLE_CHOICES_FIXTURES = ( + ["foo", "bar", "baz"], + MyEnum, +) + + ########################################## # Stuctured Generation Inputs ########################################## @@ -264,21 +277,33 @@ def test_generate_json(request, model_fixture, sample_schema): @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +@pytest.mark.parametrize("sample_choices", ALL_SAMPLE_CHOICES_FIXTURES) def test_generate_choice(request, model_fixture, sample_choices): model = request.getfixturevalue(model_fixture) generator = generate.choice(model, sample_choices) res = generator(**get_inputs(model_fixture)) - assert res in sample_choices + if isinstance(sample_choices, type(Enum)): + assert res in [elt.value for elt in sample_choices] + else: + assert res in sample_choices @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +@pytest.mark.parametrize("sample_choices", ALL_SAMPLE_CHOICES_FIXTURES) def test_generate_choice_twice(request, model_fixture, sample_choices): model = request.getfixturevalue(model_fixture) generator = generate.choice(model, sample_choices) res = generator(**get_inputs(model_fixture)) - assert res in sample_choices + if isinstance(sample_choices, type(Enum)): + assert res in [elt.value for elt in sample_choices] + else: + assert res in sample_choices + res = generator(**get_inputs(model_fixture)) - assert res in sample_choices + if isinstance(sample_choices, type(Enum)): + assert res in [elt.value for elt in sample_choices] + else: + assert res in sample_choices @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) diff --git a/tests/generate/test_integration_llamacpp.py b/tests/generate/test_integration_llamacpp.py index 8d4596d60..fd5be2171 100644 --- a/tests/generate/test_integration_llamacpp.py +++ b/tests/generate/test_integration_llamacpp.py @@ -274,6 +274,7 @@ def test_llama_cpp_pre_tokenizer_remains_broken(): generate.choice(model, ["skirt", "dress", "pen", "jacket"]) +@pytest.mark.skip("Caching for guide was temporarily turned off") def test_RegexGuide_caching(model, temp_cache_dir): import llama_cpp diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py index 2462d9fcf..92c5d789c 100644 --- a/tests/generate/test_integration_transformers.py +++ b/tests/generate/test_integration_transformers.py @@ -1,6 +1,7 @@ import datetime import re from enum import Enum +from functools import partial from typing import List, Union import pytest @@ -354,6 +355,29 @@ class User(BaseModel): assert result.user_id in [1, 2] +def add(a: int, b: int) -> int: + return a + b + + +def mul(c: float, d: float) -> float: + return c * d + + +def test_transformers_json_function_enum(model): + prompt = "Output some JSON " + + class Operation(Enum): + add = partial(add) + mul = partial(mul) + + result = generate.json(model, Operation)(prompt, seed=0) + assert isinstance(result, dict) + assert len(result) == 2 + for k, v in result.items(): + assert k in ["a", "b", "c", "d"] + assert isinstance(v, (int, float)) + + def test_transformers_json_array(model): prompt = "Output some JSON " @@ -492,6 +516,7 @@ def test_transformers_use_existing_model_and_tokenizer(): assert isinstance(sequence, str) +@pytest.mark.skip("Caching for guide was temporarily turned off") def test_RegexGuide_caching(temp_cache_dir): import outlines.caching from outlines.fsm.guide import cached_create_states_mapping diff --git a/tests/models/test_mlxlm.py b/tests/models/test_mlxlm.py new file mode 100644 index 000000000..20e59da81 --- /dev/null +++ b/tests/models/test_mlxlm.py @@ -0,0 +1,100 @@ +import pytest + +from outlines.models.mlxlm import mlxlm +from outlines.models.transformers import TransformerTokenizer + +try: + import mlx.core as mx + + HAS_MLX = mx.metal.is_available() +except ImportError: + HAS_MLX = False + + +TEST_MODEL = "mlx-community/SmolLM-135M-Instruct-4bit" + + +@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") +def test_mlxlm_model(): + model = mlxlm(TEST_MODEL) + assert hasattr(model, "model") + assert hasattr(model, "tokenizer") + assert isinstance(model.tokenizer, TransformerTokenizer) + + +@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") +def test_mlxlm_tokenizer(): + model = mlxlm(TEST_MODEL) + + # Test single string encoding/decoding + test_text = "Hello, world!" + token_ids = mx.array(model.mlx_tokenizer.encode(test_text)) + assert isinstance(token_ids, mx.array) + + +@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") +def test_mlxlm_generate(): + from outlines.generate.api import GenerationParameters, SamplingParameters + + model = mlxlm(TEST_MODEL) + prompt = "Write a haiku about programming:" + + # Test with basic generation parameters + gen_params = GenerationParameters(max_tokens=50, stop_at=None, seed=None) + + # Test with different sampling parameters + sampling_params = SamplingParameters( + sampler="multinomial", num_samples=1, top_p=0.9, top_k=None, temperature=0.7 + ) + + # Test generation + output = model.generate(prompt, gen_params, None, sampling_params) + assert isinstance(output, str) + assert len(output) > 0 + + +@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") +def test_mlxlm_stream(): + from outlines.generate.api import GenerationParameters, SamplingParameters + + model = mlxlm(TEST_MODEL) + prompt = "Count from 1 to 5:" + + gen_params = GenerationParameters(max_tokens=20, stop_at=None, seed=None) + + sampling_params = SamplingParameters( + sampler="greedy", # Use greedy sampling for deterministic output + num_samples=1, + top_p=None, + top_k=None, + temperature=0.0, + ) + + # Test streaming + stream = model.stream(prompt, gen_params, None, sampling_params) + tokens = list(stream) + assert len(tokens) > 0 + assert all(isinstance(token, str) for token in tokens) + + # Test that concatenated streaming output matches generate output + streamed_text = "".join(tokens) + generated_text = model.generate(prompt, gen_params, None, sampling_params) + assert streamed_text == generated_text + + +@pytest.mark.skipif(not HAS_MLX, reason="MLX tests require Apple Silicon") +def test_mlxlm_errors(): + model = mlxlm(TEST_MODEL) + + # Test batch inference (should raise NotImplementedError) + with pytest.raises(NotImplementedError): + from outlines.generate.api import GenerationParameters, SamplingParameters + + gen_params = GenerationParameters(max_tokens=10, stop_at=None, seed=None) + sampling_params = SamplingParameters("multinomial", 1, None, None, 1.0) + model.generate(["prompt1", "prompt2"], gen_params, None, sampling_params) + + # Test beam search (should raise NotImplementedError) + with pytest.raises(NotImplementedError): + sampling_params = SamplingParameters("beam_search", 1, None, None, 1.0) + model.generate("test prompt", gen_params, None, sampling_params) diff --git a/tests/processors/test_base_processor.py b/tests/processors/test_base_processor.py new file mode 100644 index 000000000..cd9f48278 --- /dev/null +++ b/tests/processors/test_base_processor.py @@ -0,0 +1,74 @@ +from typing import List + +import jax.numpy as jnp +import numpy as np +import pytest +import torch + +from outlines.processors.base_logits_processor import OutlinesLogitsProcessor + +arrays = { + "list": [[1.0, 2.0], [3.0, 4.0]], + "np": np.array([[1, 2], [3, 4]], dtype=np.float32), + "jax": jnp.array([[1, 2], [3, 4]], dtype=jnp.float32), + "torch": torch.tensor([[1, 2], [3, 4]], dtype=torch.float32), +} + +try: + import mlx.core as mx + + arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32) +except ImportError: + pass + +try: + import jax.numpy as jnp + + arrays["jax"] = jnp.array([[1, 2], [3, 4]], dtype=jnp.float32) +except ImportError: + pass + + +# Mock implementation of the abstract class for testing +class MockLogitsProcessor(OutlinesLogitsProcessor): + def process_logits( + self, input_ids: List[List[int]], logits: torch.Tensor + ) -> torch.Tensor: + # For testing purposes, let's just return logits multiplied by 2 + return logits * 2 + + +@pytest.fixture +def processor(): + """Fixture for creating an instance of the MockLogitsProcessor.""" + return MockLogitsProcessor() + + +@pytest.mark.parametrize("array_type", arrays.keys()) +def test_to_torch(array_type, processor): + data = arrays[array_type] + torch_tensor = processor._to_torch(data) + assert isinstance(torch_tensor, torch.Tensor) + assert torch.allclose( + torch_tensor.cpu(), torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) + ) + + +@pytest.mark.parametrize("array_type", arrays.keys()) +def test_from_torch(array_type, processor): + torch_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) + data = processor._from_torch(torch_tensor, type(arrays[array_type])) + assert isinstance(data, type(arrays[array_type])) + assert np.allclose(data, arrays[array_type]) + + +@pytest.mark.parametrize("array_type", arrays.keys()) +def test_call(array_type, processor): + input_ids = arrays[array_type] + logits = arrays[array_type] + processed_logits = processor(input_ids, logits) + + assert isinstance(processed_logits, type(arrays[array_type])) + assert np.allclose( + np.array(processed_logits), np.array([[2.0, 4.0], [6.0, 8.0]], dtype=np.float32) + )