Merge branch 'main' into feature/azure-openai
# Conflicts:
#	spacy_llm/tests/compat.py
rmitsch committed Oct 5, 2023
2 parents f1bcc2d + fad5253 commit 55f1db0
Showing 17 changed files with 283 additions and 21 deletions.
4 changes: 3 additions & 1 deletion spacy_llm/models/hf/dolly.py
@@ -14,7 +14,9 @@ def init_model(self) -> Any:
         """Sets up HF model and needed utilities.
         RETURNS (Any): HF model.
         """
-        return transformers.pipeline(model=self._name, **self._config_init)
+        return transformers.pipeline(
+            model=self._name, return_full_text=False, **self._config_init
+        )
 
     def __call__(self, prompts: Iterable[str]) -> Iterable[str]:  # type: ignore[override]
         """Queries Dolly HF model.
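The `return_full_text=False` argument added here (and to the Falcon and Llama 2 wrappers below) makes the Hugging Face `text-generation` pipeline return only the newly generated continuation instead of echoing the prompt first. A minimal, hedged sketch of the difference (the model name is just a stand-in):

```python
# Illustrative sketch only; gpt2 stands in for any text-generation model.
from transformers import pipeline

generator = pipeline("text-generation", model="gpt2")
prompt = "The capital of France is"

# Default behaviour: the returned string starts with the prompt itself.
full = generator(prompt, max_new_tokens=10)[0]["generated_text"]
assert full.startswith(prompt)

# With return_full_text=False only the continuation comes back,
# which is what the updated tests below assert against the saved prompt/response pairs.
completion = generator(prompt, max_new_tokens=10, return_full_text=False)[0]["generated_text"]
assert not completion.startswith(prompt)
```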
1 change: 1 addition & 0 deletions spacy_llm/models/hf/falcon.py
@@ -38,6 +38,7 @@ def init_model(self) -> Any:
             "text-generation",
             model=self._name,
             tokenizer=self._tokenizer,
+            return_full_text=False,
             **self._config_init,
         )
 
1 change: 1 addition & 0 deletions spacy_llm/models/hf/llama2.py
@@ -32,6 +32,7 @@ def init_model(self) -> Any:
             "text-generation",
             model=self._name,
             use_auth_token=True,
+            return_full_text=False,
             **self._config_init,
         )
 
4 changes: 3 additions & 1 deletion spacy_llm/models/hf/openllama.py
@@ -54,7 +54,9 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]:  # type: ignore[override]
         assert hasattr(self._model, "generate")
         return [
             self._tokenizer.decode(
-                self._model.generate(input_ids=tii, **self._config_run)[0],
+                self._model.generate(input_ids=tii, **self._config_run)[
+                    :, tii.shape[1] :
+                ][0],
             )
             for tii in tokenized_input_ids
         ]
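The new slice relies on how `generate()` behaves for decoder-only models: the returned tensor holds the prompt tokens followed by the newly generated tokens, so dropping the first `input_ids.shape[1]` columns keeps only the completion. A rough standalone sketch of the same idea, with gpt2 as a stand-in model:

```python
# Standalone sketch of the prompt-stripping slice used above; gpt2 is only a stand-in.
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("The capital of France is", return_tensors="pt").input_ids
output_ids = model.generate(input_ids, max_new_tokens=5)  # shape: (1, prompt_len + new_tokens)
completion_ids = output_ids[:, input_ids.shape[1] :]  # drop the echoed prompt tokens
print(tokenizer.decode(completion_ids[0], skip_special_tokens=True))
```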
12 changes: 7 additions & 5 deletions spacy_llm/models/hf/stablelm.py
@@ -68,8 +68,8 @@ def hf_account(self) -> str:
 
     def __call__(self, prompts: Iterable[str]) -> Iterable[str]:  # type: ignore[override]
         assert callable(self._tokenizer)
-        tokenized_prompts = [
-            self._tokenizer(prompt, return_tensors="pt")
+        tokenized_input_ids = [
+            self._tokenizer(prompt, return_tensors="pt").input_ids
             for prompt in (
                 # Add prompt formatting for tuned model.
                 prompts
@@ -81,15 +81,17 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]:  # type: ignore[override]
             )
         ]
         if self._device:
-            tokenized_prompts = [tp.to(self._device) for tp in tokenized_prompts]
+            tokenized_input_ids = [tp.to(self._device) for tp in tokenized_input_ids]
 
         assert hasattr(self._model, "generate")
         return [
             self._tokenizer.decode(
-                self._model.generate(**prompt, **self._config_run)[0],
+                self._model.generate(input_ids=tii, **self._config_run)[
+                    :, tii.shape[1] :
+                ][0],
                 skip_special_tokens=True,
             )
-            for prompt in tokenized_prompts
+            for tii in tokenized_input_ids
         ]
 
     @staticmethod
1 change: 1 addition & 0 deletions spacy_llm/models/rest/openai/model.py
@@ -154,6 +154,7 @@ def get_model_names(cls) -> Tuple[str, ...]:
             "gpt-3.5-turbo-16k",
             "gpt-3.5-turbo-0613",
             "gpt-3.5-turbo-0613-16k",
+            "gpt-3.5-turbo-instruct",
             # text-davinci
             "text-davinci-002",
             "text-davinci-003",
18 changes: 14 additions & 4 deletions spacy_llm/models/rest/openai/registry.py
@@ -89,6 +89,7 @@ def openai_gpt_3_5_v2(
         "gpt-3.5-turbo-16k",
         "gpt-3.5-turbo-0613",
         "gpt-3.5-turbo-0613-16k",
+        "gpt-3.5-turbo-instruct",
     ] = "gpt-3.5-turbo",  # noqa: F722,F821
     strict: bool = OpenAI.DEFAULT_STRICT,
     max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
@@ -98,14 +99,18 @@
     """Returns OpenAI instance for 'gpt-3.5' model using REST to prompt API.
     config (Dict[Any, Any]): LLM config passed on to the model's initialization.
-    name (Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-0613-16k"]): Model to use.
+    name (Literal[
+        "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-0613-16k", "gpt-3.5-turbo-instruct"
+    ]): Model to use.
     RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-3.5' model
     DOCS: https://spacy.io/api/large-language-models#models
     """
     return OpenAI(
         name=name,
-        endpoint=Endpoints.CHAT.value,
+        endpoint=Endpoints.CHAT.value
+        # gpt-3.5-turbo-instruct runs on the non-chat endpoint, so we use that one by default to allow batching.
+        if name != "gpt-3.5-turbo-instruct" else Endpoints.NON_CHAT.value,
         config=config,
         strict=strict,
         max_tries=max_tries,
@@ -122,6 +127,7 @@ def openai_gpt_3_5(
         "gpt-3.5-turbo-16k",
         "gpt-3.5-turbo-0613",
         "gpt-3.5-turbo-0613-16k",
+        "gpt-3.5-turbo-instruct",
     ] = "gpt-3.5-turbo",  # noqa: F722,F821
     strict: bool = OpenAI.DEFAULT_STRICT,
     max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
@@ -131,14 +137,18 @@
     """Returns OpenAI instance for 'gpt-3.5' model using REST to prompt API.
     config (Dict[Any, Any]): LLM config passed on to the model's initialization.
-    name (Literal["gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-0613-16k"]): Model to use.
+    name (Literal[
+        "gpt-3.5-turbo", "gpt-3.5-turbo-16k", "gpt-3.5-turbo-0613", "gpt-3.5-turbo-0613-16k", "gpt-3.5-turbo-instruct"
+    ]): Model to use.
     RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-3.5' model
     DOCS: https://spacy.io/api/large-language-models#models
     """
     return OpenAI(
         name=name,
-        endpoint=Endpoints.CHAT.value,
+        endpoint=Endpoints.CHAT.value
+        # gpt-3.5-turbo-instruct runs on the non-chat endpoint, so we use that one by default to allow batching.
+        if name != "gpt-3.5-turbo-instruct" else Endpoints.NON_CHAT.value,
         config=config,
         strict=strict,
         max_tries=max_tries,
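Usage-wise, nothing changes for callers: selecting the new model name is enough, and the factory routes it to the completions endpoint. A hedged sketch of how this would be used; the registry handle `spacy.GPT-3-5.v2` and the `spacy.NoOp.v1` task are assumed here purely for illustration, and an `OPENAI_API_KEY` must be set:

```python
# Hedged usage sketch; registry handle and task name are assumptions for illustration.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        "task": {"@llm_tasks": "spacy.NoOp.v1"},
        # gpt-3.5-turbo-instruct is routed to the non-chat (completions) endpoint above.
        "model": {"@llm_models": "spacy.GPT-3-5.v2", "name": "gpt-3.5-turbo-instruct"},
    },
)
doc = nlp("This is a test.")
```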
4 changes: 4 additions & 0 deletions spacy_llm/models/rest/palm/__init__.py
@@ -0,0 +1,4 @@
from .model import Endpoints, PaLM
from .registry import palm_bison

__all__ = ["palm_bison", "PaLM", "Endpoints"]
113 changes: 113 additions & 0 deletions spacy_llm/models/rest/palm/model.py
@@ -0,0 +1,113 @@
import os
import warnings
from enum import Enum
from typing import Any, Dict, Iterable, List, Sized, Tuple

import requests  # type: ignore[import]
import srsly  # type: ignore[import]
from requests import HTTPError

from ..base import REST


class Endpoints(str, Enum):
    TEXT = "https://generativelanguage.googleapis.com/v1beta3/models/{model}:generateText?key={api_key}"
    MSG = "https://generativelanguage.googleapis.com/v1beta3/models/{model}:generateMessage?key={api_key}"


class PaLM(REST):
    @property
    def credentials(self) -> Dict[str, str]:
        api_key = os.getenv("PALM_API_KEY")
        if api_key is None:
            warnings.warn(
                "Could not find the API key to access the PaLM API. Ensure you have an API key "
                "set up via https://cloud.google.com/docs/authentication/api-keys#rest, then make it available as "
                "an environment variable 'PALM_API_KEY'."
            )

        assert api_key is not None
        return {"api_key": api_key}

    def _verify_auth(self) -> None:
        try:
            self(["What's 2+2?"])
        except ValueError as err:
            if "API key not valid" in str(err):
                warnings.warn(
                    "Authentication with provided API key failed. Please double-check you provided the correct "
                    "credentials."
                )
            else:
                raise err

    def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        api_responses: List[str] = []
        prompts = list(prompts)
        url = self._endpoint.format(
            model=self._name, api_key=self._credentials["api_key"]
        )

        def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
            r = self.retry(
                call_method=requests.post,
                url=url,
                headers=headers,
                json={**json_data, **self._config},
                timeout=self._max_request_time,
            )
            try:
                r.raise_for_status()
            except HTTPError as ex:
                res_content = srsly.json_loads(r.content.decode("utf-8"))
                # Include specific error message in exception.
                error_message = res_content.get("error", {}).get("message", {})
                # Catching other types of HTTPErrors (e.g., "429: too many requests")
                raise ValueError(f"Request to PaLM API failed: {error_message}") from ex
            response = r.json()

            # PaLM returns a 'filters' key when a message was filtered due to safety concerns.
            if "filters" in response:
                if self._strict:
                    raise ValueError(f"API call failed: {response}.")
                else:
                    assert isinstance(prompts, Sized)
                    return {"error": [srsly.json_dumps(response)] * len(prompts)}
            return response

        # PaLM API currently doesn't accept batch prompts, so we're making
        # a request for each iteration. This approach can be prone to rate limit
        # errors. In practice, you can adjust _max_request_time so that the
        # timeout is larger.
        uses_chat = "chat" in self._name
        responses = [
            _request(
                {
                    "prompt": {"text": prompt}
                    if not uses_chat
                    else {"messages": [{"content": prompt}]}
                }
            )
            for prompt in prompts
        ]
        for response in responses:
            if "candidates" in response:
                # Although you can set the number of candidates in PaLM to be greater than 1, we only need to return a
                # single value. In this case, we will just return the very first output.
                api_responses.append(
                    response["candidates"][0].get(
                        "content" if uses_chat else "output", srsly.json_dumps(response)
                    )
                )
            else:
                api_responses.append(srsly.json_dumps(response))

        return api_responses

    @classmethod
    def get_model_names(cls) -> Tuple[str, ...]:
        return "text-bison-001", "chat-bison-001"
42 changes: 42 additions & 0 deletions spacy_llm/models/rest/palm/registry.py
@@ -0,0 +1,42 @@
from typing import Any, Callable, Dict, Iterable

from confection import SimpleFrozenDict

from ....compat import Literal
from ....registry import registry
from .model import Endpoints, PaLM


@registry.llm_models("spacy.PaLM.v1")
def palm_bison(
    config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
    name: Literal["chat-bison-001", "text-bison-001"] = "text-bison-001",  # noqa: F821
    strict: bool = PaLM.DEFAULT_STRICT,
    max_tries: int = PaLM.DEFAULT_MAX_TRIES,
    interval: float = PaLM.DEFAULT_INTERVAL,
    max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME,
) -> Callable[[Iterable[str]], Iterable[str]]:
    """Returns Google instance for PaLM Bison model using REST to prompt API.
    name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
    strict (bool): If True, a ValueError is raised if the LLM API returns a malformed response (i.e. any kind of JSON
        or other response object that does not conform to the expectation of how a well-formed response object from
        this API should look). If False, the API error responses are returned by __call__(), but no error will
        be raised.
    max_tries (int): Max. number of tries for the API request.
    interval (float): Time interval (in seconds) for API retries. We implement a base-2 exponential backoff
        at each retry.
    max_request_time (float): Max. time (in seconds) to wait for the request to terminate before raising an exception.
    RETURNS (Callable[[Iterable[str]], Iterable[str]]): PaLM instance for the Bison models, using REST to prompt the API.
    """
    return PaLM(
        name=name,
        endpoint=Endpoints.TEXT.value
        if name in {"text-bison-001"}
        else Endpoints.MSG.value,
        config=config,
        strict=strict,
        max_tries=max_tries,
        interval=interval,
        max_request_time=max_request_time,
    )
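A hedged end-to-end sketch of wiring the new `spacy.PaLM.v1` registry entry into an `llm` pipeline component; the `spacy.NoOp.v1` task is used only to keep the example minimal, and a `PALM_API_KEY` environment variable must be set:

```python
# Hedged usage sketch for the new PaLM registry entry.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        "task": {"@llm_tasks": "spacy.NoOp.v1"},
        "model": {"@llm_models": "spacy.PaLM.v1", "name": "text-bison-001"},
    },
)
doc = nlp("What's 2+2?")
```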
1 change: 1 addition & 0 deletions spacy_llm/tests/compat.py
@@ -4,3 +4,4 @@
 has_anthropic_key = os.getenv("ANTHROPIC_API_KEY") is not None
 has_cohere_key = os.getenv("CO_API_KEY") is not None
 has_azure_openai_key = os.getenv("AZURE_OPENAI_KEY") is not None
+has_palm_key = os.getenv("PALM_API_KEY") is not None
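The new `has_palm_key` flag follows the same pattern as the other keys: tests that hit the PaLM API can be skipped when no key is configured. A sketch of the assumed usage pattern; the import path and test body are illustrative only:

```python
# Assumed usage of the new flag in a test module; the test itself is a placeholder.
import pytest

from spacy_llm.tests.compat import has_palm_key


@pytest.mark.skipif(not has_palm_key, reason="PALM_API_KEY not set")
def test_palm_smoke():
    ...
```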
11 changes: 7 additions & 4 deletions spacy_llm/tests/models/test_dolly.py
@@ -13,6 +13,7 @@
         "name": "dolly-v2-3b",
     },
     "task": {"@llm_tasks": "spacy.NoOp.v1"},
+    "save_io": True,
 }
 
 _NLP_CONFIG = """
@@ -26,6 +27,7 @@
 [components.llm]
 factory = "llm"
+save_io = True
 
 [components.llm.task]
 @llm_tasks = "spacy.NoOp.v1"
@@ -41,12 +43,13 @@
 def test_init():
     """Test initialization and simple run."""
     nlp = spacy.blank("en")
-    cfg = copy.deepcopy(_PIPE_CFG)
-    cfg["model"]["@llm_models"] = "spacy.Dolly.v1"
-    nlp.add_pipe("llm", config=cfg)
-    nlp("This is a test.")
+    nlp.add_pipe("llm", config=_PIPE_CFG)
+    doc = nlp("This is a test.")
     nlp.get_pipe("llm")._model.get_model_names()
     torch.cuda.empty_cache()
+    assert not doc.user_data["llm_io"]["llm"]["response"].startswith(
+        doc.user_data["llm_io"]["llm"]["prompt"]
+    )
 
 
 @pytest.mark.gpu
6 changes: 5 additions & 1 deletion spacy_llm/tests/models/test_falcon.py
@@ -13,6 +13,7 @@
         "name": "falcon-rw-1b",
     },
     "task": {"@llm_tasks": "spacy.NoOp.v1"},
+    "save_io": True,
 }
 
 _NLP_CONFIG = """
@@ -43,8 +44,11 @@ def test_init():
     nlp = spacy.blank("en")
     cfg = copy.deepcopy(_PIPE_CFG)
     nlp.add_pipe("llm", config=cfg)
-    nlp("This is a test.")
+    doc = nlp("This is a test.")
     torch.cuda.empty_cache()
+    assert not doc.user_data["llm_io"]["llm"]["response"].startswith(
+        doc.user_data["llm_io"]["llm"]["prompt"]
+    )
 
 
 @pytest.mark.gpu
6 changes: 5 additions & 1 deletion spacy_llm/tests/models/test_llama2.py
@@ -13,6 +13,7 @@
         "name": "Llama-2-7b-hf",
     },
     "task": {"@llm_tasks": "spacy.NoOp.v1"},
+    "save_io": True,
 }
 
 _NLP_CONFIG = """
@@ -44,8 +45,11 @@ def test_init():
     nlp = spacy.blank("en")
     cfg = copy.deepcopy(_PIPE_CFG)
     nlp.add_pipe("llm", config=cfg)
-    nlp("This is a test.")
+    doc = nlp("This is a test.")
     torch.cuda.empty_cache()
+    assert not doc.user_data["llm_io"]["llm"]["response"].startswith(
+        doc.user_data["llm_io"]["llm"]["prompt"]
+    )
 
 
 @pytest.mark.skip(reason="CI runner needs more GPU memory")