Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Define end-to-end applications #413

Merged
merged 2 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions outlines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import outlines.text.generate
from outlines.base import vectorize
from outlines.caching import clear_cache, disable_cache, get_cache
from outlines.function import Function
from outlines.prompts import prompt

__all__ = [
Expand Down
117 changes: 117 additions & 0 deletions outlines/function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import importlib.util
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

import requests

from outlines import generate, models

if TYPE_CHECKING:
from outlines.generate.api import SequenceGenerator
from outlines.prompts import Prompt


@dataclass
class Function:
"""Represents an Outlines function.

Functions are a convenient way to encapsulate a prompt template, a language
model and a Pydantic model that define the output structure. Once defined,
the function can be called with arguments that will be used to render the
prompt template.

"""

prompt_template: "Prompt"
schema: Union[str, Callable, object]
model_name: str
generator: Optional["SequenceGenerator"] = None

@classmethod
def from_github(cls, program_path: str, function_name: str = "fn"):
"""Load a function stored on GitHub"""
program_content = download_from_github(program_path)
function = extract_function_from_file(program_content, function_name)

return function

def init_generator(self):
"""Load the model and initialize the generator."""
model = models.transformers(self.model_name)
self.generator = generate.json(model, self.schema)

def __call__(self, *args, **kwargs):
"""Call the function.

.. warning::

This currently does not support batching.

Parameters
----------
args
Values to pass to the prompt template as positional arguments.
kwargs
Values to pass to the prompt template as keyword arguments.

"""
if self.generator is None:
self.init_generator()

prompt = self.prompt_template(*args, **kwargs)
return self.generator(prompt)


def download_from_github(short_path: str):
"""Download the file in which the function is stored on GitHub."""
GITHUB_BASE_URL = "https://raw.githubusercontent.com"
BRANCH = "main"

path = short_path.split("/")
if len(path) < 3:
raise ValueError(
"Please provide a valid path in the form {USERNAME}/{REPO_NAME}/{PATH_TO_FILE}."
)
elif short_path[-3:] == ".py":
raise ValueError("Do not append the `.py` extension to the program name.")

username = path[0]
repo = path[1]
path_to_file = path[2:]

url = "/".join([GITHUB_BASE_URL, username, repo, BRANCH] + path_to_file) + ".py"
rlouf marked this conversation as resolved.
Show resolved Hide resolved
result = requests.get(url)

if result.status_code == 200:
return result.text
elif result.status_code == 404:
raise ValueError(
f"Program could not be found at {url}. Please make sure you entered the GitHub username, repository name and path to the program correctly."
)
else:
result.raise_for_status()


def extract_function_from_file(content: str, function_name: str) -> Tuple[Callable]:
"""Extract a function object from a downloaded file."""

spec = importlib.util.spec_from_loader(
"outlines_function", loader=None, origin="github"
)
if spec is not None:
module = importlib.util.module_from_spec(spec)
exec(content, module.__dict__)

try:
fn = getattr(module, function_name)
except AttributeError:
raise AttributeError(
"Could not find an `outlines.Function` instance in the remote file. Make sure that the path you specified is correct."
)

if not isinstance(fn, module.outlines.Function):
raise TypeError(
f"The `{function_name}` variable in the program must be an instance of `outlines.Function`"
)

return fn
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"joblib",
"referencing",
"jsonschema",
"requests",
]
dynamic = ["version"]

Expand All @@ -52,6 +53,7 @@ test = [
"accelerate",
"beartype<0.16.0",
"datasets",
"responses",
]

[project.urls]
Expand Down Expand Up @@ -111,6 +113,8 @@ module = [
"interegular.*",
"datasets.*",
"numba.*",
"requests.*",
"responses.*",
]
ignore_missing_imports = true

Expand Down
133 changes: 133 additions & 0 deletions tests/test_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
import pytest
import responses
from pydantic import BaseModel
from requests.exceptions import HTTPError

import outlines
from outlines.function import Function, download_from_github, extract_function_from_file


def test_function_basic():
@outlines.prompt
def test_template(text: str):
"""{{ text }}"""

class Foo(BaseModel):
id: int

fn = Function(test_template, Foo, "hf-internal-testing/tiny-random-GPTJForCausalLM")

assert fn.generator is None

result = fn("test")
assert isinstance(result, BaseModel)


def test_download_from_github_invalid():
with pytest.raises(ValueError, match="Please provide"):
download_from_github("outlines/program")

with pytest.raises(ValueError, match="Do not append"):
download_from_github("outlines-dev/outlines/program.py")


@responses.activate
def test_download_from_github_success():
responses.add(
responses.GET,
"https://raw.githubusercontent.com/outlines-dev/outlines/main/program.py",
body="import outlines\n",
status=200,
)

file = download_from_github("outlines-dev/outlines/program")
assert file == "import outlines\n"

responses.add(
responses.GET,
"https://raw.githubusercontent.com/outlines-dev/outlines/main/foo/bar/program.py",
body="import outlines\n",
status=200,
)

file = download_from_github("outlines-dev/outlines/foo/bar/program")
assert file == "import outlines\n"


@responses.activate
def test_download_from_github_error():
responses.add(
responses.GET,
"https://raw.githubusercontent.com/foo/bar/main/program.py",
json={"error": "not found"},
status=404,
)

with pytest.raises(ValueError, match="Program could not be found at"):
download_from_github("foo/bar/program")

responses.add(
responses.GET,
"https://raw.githubusercontent.com/foo/bar/main/program.py",
json={"error": "Internal Server Error"},
status=500,
)

with pytest.raises(HTTPError, match="500 Server Error"):
download_from_github("foo/bar/program")


def test_extract_function_from_file():
content = """
import outlines
from pydantic import BaseModel

model = "gpt2"


@outlines.prompt
def prompt():
'''Hello'''


class User(BaseModel):
id: int
name: str


function = outlines.Function(
prompt,
User,
"gpt2",
)
"""

fn = extract_function_from_file(content, "function")
assert (
str(type(fn)) == "<class 'outlines.function.Function'>"
) # because imported via `exec`


def test_extract_function_from_file_no_function():
content = """
import outlines
from pydantic import BaseModel

@outlines.prompt
def prompt():
'''Hello'''


class User(BaseModel):
id: int
name: str

program = outlines.Function(
prompt,
User,
"gpt2",
)
"""

with pytest.raises(AttributeError, match="Could not find"):
extract_function_from_file(content, "function")