diff --git a/.gitignore b/.gitignore index 4984b18cb..08390ae3d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,7 @@ docs/build *.gguf .venv benchmarks/results +.python-version # Remove doc build folders .cache/ diff --git a/README.md b/README.md index e9285e863..d34b0984d 100644 --- a/README.md +++ b/README.md @@ -84,6 +84,29 @@ generator = outlines.generate.choice(model, ["Positive", "Negative"]) answer = generator(prompt) ``` +You can also pass these choices through en enum: + +````python +from enum import Enum + +import outlines + +class Sentiment(str, Enum): + positive = "Positive" + negative = "Negative" + +model = outlines.models.transformers("microsoft/Phi-3-mini-4k-instruct") + +prompt = """You are a sentiment-labelling assistant. +Is the following review positive or negative? + +Review: This restaurant is just awesome! +""" + +generator = outlines.generate.choice(model, Sentiment) +answer = generator(prompt) +```` + ### Type constraint You can instruct the model to only return integers or floats: diff --git a/outlines/generate/choice.py b/outlines/generate/choice.py index 595513d52..75fc71271 100644 --- a/outlines/generate/choice.py +++ b/outlines/generate/choice.py @@ -1,7 +1,10 @@ import json as pyjson +import re +from enum import Enum from functools import singledispatch -from typing import Callable, List +from typing import Callable, List, Union +from outlines.fsm.json_schema import build_regex_from_schema, get_schema_from_enum from outlines.generate.api import SequenceGeneratorAdapter from outlines.models import OpenAI from outlines.samplers import Sampler, multinomial @@ -12,12 +15,19 @@ @singledispatch def choice( - model, choices: List[str], sampler: Sampler = multinomial() + model, choices: Union[List[str], type[Enum]], sampler: Sampler = multinomial() ) -> SequenceGeneratorAdapter: - regex_str = r"(" + r"|".join(choices) + r")" + if isinstance(choices, type(Enum)): + regex_str = build_regex_from_schema(pyjson.dumps(get_schema_from_enum(choices))) + else: + choices = [re.escape(choice) for choice in choices] # type: ignore + regex_str = r"(" + r"|".join(choices) + r")" generator = regex(model, regex_str, sampler) - generator.format_sequence = lambda x: x + if isinstance(choices, type(Enum)): + generator.format_sequence = lambda x: pyjson.loads(x) + else: + generator.format_sequence = lambda x: x return generator diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index 9c288c21e..f91bc8653 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -1,5 +1,6 @@ import contextlib import re +from enum import Enum import pytest @@ -127,6 +128,18 @@ def model_t5(tmp_path_factory): ) +class MyEnum(Enum): + foo = "foo" + bar = "bar" + baz = "baz" + + +ALL_SAMPLE_CHOICES_FIXTURES = ( + ["foo", "bar", "baz"], + MyEnum, +) + + ########################################## # Stuctured Generation Inputs ########################################## @@ -264,21 +277,33 @@ def test_generate_json(request, model_fixture, sample_schema): @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +@pytest.mark.parametrize("sample_choices", ALL_SAMPLE_CHOICES_FIXTURES) def test_generate_choice(request, model_fixture, sample_choices): model = request.getfixturevalue(model_fixture) generator = generate.choice(model, sample_choices) res = generator(**get_inputs(model_fixture)) - assert res in sample_choices + if isinstance(sample_choices, type(Enum)): + assert res in [elt.value for elt in sample_choices] + else: + assert res in sample_choices @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +@pytest.mark.parametrize("sample_choices", ALL_SAMPLE_CHOICES_FIXTURES) def test_generate_choice_twice(request, model_fixture, sample_choices): model = request.getfixturevalue(model_fixture) generator = generate.choice(model, sample_choices) res = generator(**get_inputs(model_fixture)) - assert res in sample_choices + if isinstance(sample_choices, type(Enum)): + assert res in [elt.value for elt in sample_choices] + else: + assert res in sample_choices + res = generator(**get_inputs(model_fixture)) - assert res in sample_choices + if isinstance(sample_choices, type(Enum)): + assert res in [elt.value for elt in sample_choices] + else: + assert res in sample_choices @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)