From 0c5503a5a94e81a71de4a4c5d0b072c9a0020b01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Thu, 7 Dec 2023 17:48:25 +0100 Subject: [PATCH] Add a `Function` object The `Function` object encapsulates a prompt template, a (`transformers`) model name and an output structure. It can then simply be called with a prompt throughout an application. --- outlines/__init__.py | 1 + outlines/function.py | 50 ++++++++++++++++++++++++++++++++++++++++++ tests/test_function.py | 20 +++++++++++++++++ 3 files changed, 71 insertions(+) create mode 100644 outlines/function.py create mode 100644 tests/test_function.py diff --git a/outlines/__init__.py b/outlines/__init__.py index 7c8414af0..3d75d0aca 100644 --- a/outlines/__init__.py +++ b/outlines/__init__.py @@ -4,6 +4,7 @@ import outlines.text.generate from outlines.base import vectorize from outlines.caching import clear_cache, disable_cache, get_cache +from outlines.function import Function from outlines.prompts import prompt __all__ = [ diff --git a/outlines/function.py b/outlines/function.py new file mode 100644 index 000000000..308b2d322 --- /dev/null +++ b/outlines/function.py @@ -0,0 +1,50 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Callable, Optional, Union + +from outlines import generate, models + +if TYPE_CHECKING: + from outlines.generate.api import SequenceGenerator + + +@dataclass +class Function: + """Represents an Outlines function. + + Functions are a convenient way to encapsulate a prompt template, a language + model and a Pydantic model that define the output structure. Once defined, + the function can be called with arguments that will be used to render the + prompt template. + + """ + + prompt_template: Callable + model_name: str + schema: Union[str, Callable, object] + generator: Optional["SequenceGenerator"] = None + + def init_generator(self): + """Load the model and initialize the generator.""" + model = models.transformers(self.model_name) + self.generator = generate.json(model, self.schema) + + def __call__(self, *args, **kwargs): + """Call the function. + + .. warning:: + + This currently does not support batching. + + Parameters + ---------- + args + Values to pass to the prompt template as positional arguments. + kwargs + Values to pass to the prompt template as keyword arguments. + + """ + if self.generator is None: + self.init_generator() + + prompt = self.prompt_template(*args, **kwargs) + return self.generator(prompt) diff --git a/tests/test_function.py b/tests/test_function.py new file mode 100644 index 000000000..3a734cef8 --- /dev/null +++ b/tests/test_function.py @@ -0,0 +1,20 @@ +from pydantic import BaseModel + +import outlines +from outlines.function import Function + + +def test_function_basic(): + @outlines.prompt + def test_template(text: str): + """{{ text }}""" + + class Foo(BaseModel): + id: int + + fn = Function(test_template, "hf-internal-testing/tiny-random-GPTJForCausalLM", Foo) + + assert fn.generator is None + + result = fn("test") + assert isinstance(result, BaseModel)