Skip to content

Commit

Permalink
Add a Function object
Browse files Browse the repository at this point in the history
The `Function` object encapsulates a prompt template, a (`transformers`)
model name and an output structure. It can then simply be called with a
prompt throughout an application.
  • Loading branch information
rlouf committed Dec 13, 2023
1 parent add3743 commit 0c5503a
Show file tree
Hide file tree
Showing 3 changed files with 71 additions and 0 deletions.
1 change: 1 addition & 0 deletions outlines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import outlines.text.generate
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
from outlines.function import Function
from outlines.prompts import prompt

__all__ = [
Expand Down
50 changes: 50 additions & 0 deletions outlines/function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Union

from outlines import generate, models

if TYPE_CHECKING:
from outlines.generate.api import SequenceGenerator


@dataclass
class Function:
"""Represents an Outlines function.
Functions are a convenient way to encapsulate a prompt template, a language
model and a Pydantic model that define the output structure. Once defined,
the function can be called with arguments that will be used to render the
prompt template.
"""

prompt_template: Callable
model_name: str
schema: Union[str, Callable, object]
generator: Optional["SequenceGenerator"] = None

def init_generator(self):
"""Load the model and initialize the generator."""
model = models.transformers(self.model_name)
self.generator = generate.json(model, self.schema)

def __call__(self, *args, **kwargs):
"""Call the function.
.. warning::
This currently does not support batching.
Parameters
----------
args
Values to pass to the prompt template as positional arguments.
kwargs
Values to pass to the prompt template as keyword arguments.
"""
if self.generator is None:
self.init_generator()

prompt = self.prompt_template(*args, **kwargs)
return self.generator(prompt)
20 changes: 20 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from pydantic import BaseModel

import outlines
from outlines.function import Function


def test_function_basic():
@outlines.prompt
def test_template(text: str):
"""{{ text }}"""

class Foo(BaseModel):
id: int

fn = Function(test_template, "hf-internal-testing/tiny-random-GPTJForCausalLM", Foo)

assert fn.generator is None

result = fn("test")
assert isinstance(result, BaseModel)

0 comments on commit 0c5503a

Please sign in to comment.