Skip to content

Commit

Permalink
Load model stored in GitHub
Browse files Browse the repository at this point in the history
  • Loading branch information
rlouf committed Dec 11, 2023
1 parent 27e5b4a commit 8cf9192
Show file tree
Hide file tree
Showing 3 changed files with 191 additions and 5 deletions.
77 changes: 74 additions & 3 deletions outlines/function.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
import importlib.util
import types
from dataclasses import dataclass
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union

import requests

from outlines import generate, models

if TYPE_CHECKING:
from outlines.generate.api import SequenceGenerator
from outlines.prompts import Prompt


@dataclass
Expand All @@ -18,11 +23,20 @@ class Function:
"""

prompt_template: Callable
model_name: str
prompt_template: "Prompt"
schema: Union[str, Callable, object]
model_name: str
generator: Optional["SequenceGenerator"] = None

@classmethod
def from_github(cls, program_path: str, function_name: str = "fn"):
"""Load a function stored on GitHub"""
program_content = download_from_github(program_path)
_, function = extract_function_from_file(
program_content, function_name, program_path
)
return function

def init_generator(self):
"""Load the model and initialize the generator."""
model = models.transformers(self.model_name)
Expand All @@ -48,3 +62,60 @@ def __call__(self, *args, **kwargs):

prompt = self.prompt_template(*args, **kwargs)
return self.generator(prompt)


def download_from_github(short_path: str):
"""Download the file in which the function is stored on GitHub."""
GITHUB_BASE_URL = "https://raw.githubusercontent.com"
BRANCH = "main"

path = short_path.split("/")
if len(path) < 3:
raise ValueError(
"Please provide a valid path in the form {USERNAME}/{REPO_NAME}/{PATH_TO_FILE}."
)
elif short_path[-3:] == ".py":
raise ValueError("Do not append the `.py` extension to the program name.")

username = path[0]
repo = path[1]
path_to_file = path[2:]

url = "/".join([GITHUB_BASE_URL, username, repo, BRANCH] + path_to_file) + ".py"
result = requests.get(url)

if result.status_code == 200:
return result.text
elif result.status_code == 404:
raise ValueError(
f"Program could not be found at {url}. Please make sure you entered the GitHub username, repository name and path to the program correctly."
)
else:
result.raise_for_status()


def extract_function_from_file(
content: str, function_name: str, program_path: str = ""
) -> Tuple[types.ModuleType, Callable]:
"""Extract a function object from a downloaded file."""

spec = importlib.util.spec_from_loader(
"outlines_function", loader=None, origin=program_path
)
if spec is not None:
module = importlib.util.module_from_spec(spec)
exec(content, module.__dict__)

try:
fn = getattr(module, function_name)
if not isinstance(fn, module.outlines.Function):
raise TypeError(
f"The `{function_name}` variable in the program must be an instance of `outlines.Function`"
)

return module, fn

except AttributeError:
raise AttributeError(
f"Could not find an `outlines.Function` instance in the remote file. Make sure that an `instance` called '{function_name}' is present in the file located at {program_path}"
)
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ dependencies = [
"joblib",
"referencing",
"jsonschema",
"requests",
]
dynamic = ["version"]

Expand All @@ -52,6 +53,7 @@ test = [
"accelerate",
"beartype<0.16.0",
"datasets",
"responses",
]

[project.urls]
Expand Down Expand Up @@ -111,6 +113,8 @@ module = [
"interegular.*",
"datasets.*",
"numba.*",
"requests.*",
"responses.*",
]
ignore_missing_imports = true

Expand Down
115 changes: 113 additions & 2 deletions tests/test_function.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
import pytest
import responses
from pydantic import BaseModel
from requests.exceptions import HTTPError

import outlines
from outlines.function import Function
from outlines.function import Function, download_from_github, extract_function_from_file


def test_function_basic():
Expand All @@ -12,9 +15,117 @@ def test_template(text: str):
class Foo(BaseModel):
id: int

fn = Function(test_template, "hf-internal-testing/tiny-random-GPTJForCausalLM", Foo)
fn = Function(test_template, Foo, "hf-internal-testing/tiny-random-GPTJForCausalLM")

assert fn.generator is None

result = fn("test")
assert isinstance(result, BaseModel)


def test_download_from_github_invalid():
with pytest.raises(ValueError, match="Please provide"):
download_from_github("outlines/program")

with pytest.raises(ValueError, match="Do not append"):
download_from_github("outlines-dev/outlines/program.py")


@responses.activate
def test_download_from_github_success():
responses.add(
responses.GET,
"https://raw.githubusercontent.com/outlines-dev/outlines/main/program.py",
body="import outlines\n",
status=200,
)

file = download_from_github("outlines-dev/outlines/program")
assert file == "import outlines\n"

responses.add(
responses.GET,
"https://raw.githubusercontent.com/outlines-dev/outlines/main/foo/bar/program.py",
body="import outlines\n",
status=200,
)

file = download_from_github("outlines-dev/outlines/foo/bar/program")
assert file == "import outlines\n"


@responses.activate
def test_download_from_github_error():
responses.add(
responses.GET,
"https://raw.githubusercontent.com/foo/bar/main/program.py",
json={"error": "not found"},
status=404,
)

with pytest.raises(ValueError, match="Program could not be found at"):
download_from_github("foo/bar/program")

responses.add(
responses.GET,
"https://raw.githubusercontent.com/foo/bar/main/program.py",
json={"error": "Internal Server Error"},
status=500,
)

with pytest.raises(HTTPError, match="500 Server Error"):
download_from_github("foo/bar/program")


def test_extract_function_from_file():
content = """
import outlines
from pydantic import BaseModel
model = "gpt2"
@outlines.prompt
def prompt():
'''Hello'''
class User(BaseModel):
id: int
name: str
function = outlines.Function(
prompt,
User,
"gpt2",
)
"""

module, fn = extract_function_from_file(content, "function")
assert isinstance(fn, module.outlines.Function)


def test_extract_function_from_file_no_function():
content = """
import outlines
from pydantic import BaseModel
@outlines.prompt
def prompt():
'''Hello'''
class User(BaseModel):
id: int
name: str
program = outlines.Function(
prompt,
User,
"gpt2",
)
"""

with pytest.raises(AttributeError, match="Could not find"):
extract_function_from_file(content, "function")

0 comments on commit 8cf9192

Please sign in to comment.