From 4f34714057e1d10f91c51dbcecdd5e74c470db29 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?R=C3=A9mi=20Louf?= Date: Fri, 8 Dec 2023 14:21:45 +0100 Subject: [PATCH] Load model stored in GitHub --- outlines/function.py | 73 +++++++++++++++++++++++-- pyproject.toml | 4 ++ tests/test_function.py | 117 ++++++++++++++++++++++++++++++++++++++++- 3 files changed, 189 insertions(+), 5 deletions(-) diff --git a/outlines/function.py b/outlines/function.py index 308b2d322..48577be8f 100644 --- a/outlines/function.py +++ b/outlines/function.py @@ -1,10 +1,14 @@ +import importlib.util from dataclasses import dataclass -from typing import TYPE_CHECKING, Callable, Optional, Union +from typing import TYPE_CHECKING, Callable, Optional, Tuple, Union + +import requests from outlines import generate, models if TYPE_CHECKING: from outlines.generate.api import SequenceGenerator + from outlines.prompts import Prompt @dataclass @@ -18,11 +22,19 @@ class Function: """ - prompt_template: Callable - model_name: str + prompt_template: "Prompt" schema: Union[str, Callable, object] + model_name: str generator: Optional["SequenceGenerator"] = None + @classmethod + def from_github(cls, program_path: str, function_name: str = "fn"): + """Load a function stored on GitHub""" + program_content = download_from_github(program_path) + function = extract_function_from_file(program_content, function_name) + + return function + def init_generator(self): """Load the model and initialize the generator.""" model = models.transformers(self.model_name) @@ -48,3 +60,58 @@ def __call__(self, *args, **kwargs): prompt = self.prompt_template(*args, **kwargs) return self.generator(prompt) + + +def download_from_github(short_path: str): + """Download the file in which the function is stored on GitHub.""" + GITHUB_BASE_URL = "https://raw.githubusercontent.com" + BRANCH = "main" + + path = short_path.split("/") + if len(path) < 3: + raise ValueError( + "Please provide a valid path in the form {USERNAME}/{REPO_NAME}/{PATH_TO_FILE}." + ) + elif short_path[-3:] == ".py": + raise ValueError("Do not append the `.py` extension to the program name.") + + username = path[0] + repo = path[1] + path_to_file = path[2:] + + url = "/".join([GITHUB_BASE_URL, username, repo, BRANCH] + path_to_file) + ".py" + result = requests.get(url) + + if result.status_code == 200: + return result.text + elif result.status_code == 404: + raise ValueError( + f"Program could not be found at {url}. Please make sure you entered the GitHub username, repository name and path to the program correctly." + ) + else: + result.raise_for_status() + + +def extract_function_from_file(content: str, function_name: str) -> Tuple[Callable]: + """Extract a function object from a downloaded file.""" + + spec = importlib.util.spec_from_loader( + "outlines_function", loader=None, origin="github" + ) + if spec is not None: + module = importlib.util.module_from_spec(spec) + exec(content, module.__dict__) + + try: + fn = getattr(module, function_name) + except AttributeError: + raise AttributeError( + "Could not find an `outlines.Function` instance in the remote file. Make sure that the path you specified is correct." + ) + + if not isinstance(fn, module.outlines.Function): + raise TypeError( + f"The `{function_name}` variable in the program must be an instance of `outlines.Function`" + ) + + return fn diff --git a/pyproject.toml b/pyproject.toml index f060321fa..57ee4d2bd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,6 +38,7 @@ dependencies = [ "joblib", "referencing", "jsonschema", + "requests", ] dynamic = ["version"] @@ -52,6 +53,7 @@ test = [ "accelerate", "beartype<0.16.0", "datasets", + "responses", ] [project.urls] @@ -111,6 +113,8 @@ module = [ "interegular.*", "datasets.*", "numba.*", + "requests.*", + "responses.*", ] ignore_missing_imports = true diff --git a/tests/test_function.py b/tests/test_function.py index 3a734cef8..24e132d42 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -1,7 +1,10 @@ +import pytest +import responses from pydantic import BaseModel +from requests.exceptions import HTTPError import outlines -from outlines.function import Function +from outlines.function import Function, download_from_github, extract_function_from_file def test_function_basic(): @@ -12,9 +15,119 @@ def test_template(text: str): class Foo(BaseModel): id: int - fn = Function(test_template, "hf-internal-testing/tiny-random-GPTJForCausalLM", Foo) + fn = Function(test_template, Foo, "hf-internal-testing/tiny-random-GPTJForCausalLM") assert fn.generator is None result = fn("test") assert isinstance(result, BaseModel) + + +def test_download_from_github_invalid(): + with pytest.raises(ValueError, match="Please provide"): + download_from_github("outlines/program") + + with pytest.raises(ValueError, match="Do not append"): + download_from_github("outlines-dev/outlines/program.py") + + +@responses.activate +def test_download_from_github_success(): + responses.add( + responses.GET, + "https://raw.githubusercontent.com/outlines-dev/outlines/main/program.py", + body="import outlines\n", + status=200, + ) + + file = download_from_github("outlines-dev/outlines/program") + assert file == "import outlines\n" + + responses.add( + responses.GET, + "https://raw.githubusercontent.com/outlines-dev/outlines/main/foo/bar/program.py", + body="import outlines\n", + status=200, + ) + + file = download_from_github("outlines-dev/outlines/foo/bar/program") + assert file == "import outlines\n" + + +@responses.activate +def test_download_from_github_error(): + responses.add( + responses.GET, + "https://raw.githubusercontent.com/foo/bar/main/program.py", + json={"error": "not found"}, + status=404, + ) + + with pytest.raises(ValueError, match="Program could not be found at"): + download_from_github("foo/bar/program") + + responses.add( + responses.GET, + "https://raw.githubusercontent.com/foo/bar/main/program.py", + json={"error": "Internal Server Error"}, + status=500, + ) + + with pytest.raises(HTTPError, match="500 Server Error"): + download_from_github("foo/bar/program") + + +def test_extract_function_from_file(): + content = """ +import outlines +from pydantic import BaseModel + +model = "gpt2" + + +@outlines.prompt +def prompt(): + '''Hello''' + + +class User(BaseModel): + id: int + name: str + + +function = outlines.Function( + prompt, + User, + "gpt2", +) + """ + + fn = extract_function_from_file(content, "function") + assert ( + str(type(fn)) == "" + ) # because imported via `exec` + + +def test_extract_function_from_file_no_function(): + content = """ +import outlines +from pydantic import BaseModel + +@outlines.prompt +def prompt(): + '''Hello''' + + +class User(BaseModel): + id: int + name: str + +program = outlines.Function( + prompt, + User, + "gpt2", +) + """ + + with pytest.raises(AttributeError, match="Could not find"): + extract_function_from_file(content, "function")