From 8f40f50dfd6ed990b9a8910042375baf04e92bd5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 29 Nov 2024 18:46:02 +0100 Subject: [PATCH] Add `to_regex` method to the different types --- outlines/types/__init__.py | 34 ++++++++++++++++++++++++++++++++++ tests/test_types.py | 12 ++++++++++++ 2 files changed, 46 insertions(+) diff --git a/outlines/types/__init__.py b/outlines/types/__init__.py index 46f49f36c..8a79b8162 100644 --- a/outlines/types/__init__.py +++ b/outlines/types/__init__.py @@ -8,6 +8,8 @@ from pydantic import BaseModel, TypeAdapter from typing_extensions import _TypedDictMeta # type: ignore +from outlines.fsm.json_schema import build_regex_from_schema + from . import airports, countries from .email import Email from .isbn import ISBN @@ -30,6 +32,7 @@ class Json: """ definition: Union[str, dict] + whitespace_pattern: str = " " def to_json_schema(self): if isinstance(self.definition, str): @@ -52,11 +55,21 @@ def to_json_schema(self): return schema + def to_regex(self): + schema = self.to_json_schema() + schema_str = json.dumps(schema) + return build_regex_from_schema(schema_str, self.whitespace_pattern) + @dataclass class List: definition: list + def to_regex(self): + raise NotImplementedError( + "Structured generation for lists of objects are not implemented yet." + ) + @dataclass class Choice: @@ -67,3 +80,24 @@ class Choice: def __post_init__(self): if isinstance(self.definition, list): self.definition = Enum("Definition", [(x, x) for x in self.definition]) + + def to_list(self): + if isinstance(self.definition, list): + return self.definition + else: + return [x.value for x in self.definition] + + def to_regex(self): + choices = self.to_list() + regex_str = r"(" + r"|".join(choices) + r")" + return regex_str + + +@dataclass +class Regex: + """Represents a string defined by a regular expression.""" + + definition: str + + def to_regex(self): + return self.definition diff --git a/tests/test_types.py b/tests/test_types.py index d54ac1479..0fcaec140 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -52,6 +52,18 @@ def test_type_choice(): choice_type = types.Choice(choices) assert choice_type.definition.a.value == "a" + regex_str = choice_type.to_regex() + assert regex_str == "(a|b)" + + +def test_type_list(): + class Foo(BaseModel): + bar: int + + list_type = types.List(Foo) + with pytest.raises(NotImplementedError, match="Structured"): + list_type.to_regex() + @pytest.mark.parametrize( "custom_type,test_string,should_match",