-
-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' into feature/azure-openai
# Conflicts: # spacy_llm/tests/compat.py
- Loading branch information
Showing
17 changed files
with
283 additions
and
21 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .model import Endpoints, PaLM | ||
from .registry import palm_bison | ||
|
||
__all__ = ["palm_bison", "PaLM", "Endpoints"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import os | ||
import warnings | ||
from enum import Enum | ||
from typing import Any, Dict, Iterable, List, Sized, Tuple | ||
|
||
import requests # type: ignore[import] | ||
import srsly # type: ignore[import] | ||
from requests import HTTPError | ||
|
||
from ..base import REST | ||
|
||
|
||
class Endpoints(str, Enum):
    """URL templates for the PaLM REST endpoints.

    Every value carries a `{model}` and an `{api_key}` placeholder that is meant to be filled in with `str.format`
    before the request is sent.
    """

    # Template for the `generateText` method.
    TEXT = "https://generativelanguage.googleapis.com/v1beta3/models/{model}:generateText?key={api_key}"

    # Template for the `generateMessage` method.
    MSG = "https://generativelanguage.googleapis.com/v1beta3/models/{model}:generateMessage?key={api_key}"
|
||
|
||
class PaLM(REST):
    """REST-based model wrapper for Google's PaLM text and chat models.

    Prompts are sent one at a time to the endpoint this instance was configured with; one response string is
    returned per prompt.
    """

    @property
    def credentials(self) -> Dict[str, str]:
        """Read the PaLM API key from the environment variable 'PALM_API_KEY'.
        RETURNS (Dict[str, str]): Mapping with the single key 'api_key'.
        """
        api_key = os.getenv("PALM_API_KEY")
        if api_key is None:
            warnings.warn(
                "Could not find the API key to access the PaLM API. Ensure you have an API key "
                "set up via https://cloud.google.com/docs/authentication/api-keys#rest, then make it available as "
                "an environment variable 'PALM_API_KEY'."
            )

        # Kept as an assertion so that a missing key still surfaces as an AssertionError right after the warning.
        assert api_key is not None
        return {"api_key": api_key}

    def _verify_auth(self) -> None:
        """Send a trivial prompt to check that the API key is accepted.

        An invalid key only triggers a warning; any other ValueError raised by the request is propagated unchanged.
        """
        try:
            self(["What's 2+2?"])
        except ValueError as err:
            if "API key not valid" not in str(err):
                raise
            warnings.warn(
                "Authentication with provided API key failed. Please double-check you provided the correct "
                "credentials."
            )

    def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
        """Query the PaLM API once per prompt and collect the responses.
        prompts (Iterable[str]): Prompts to send.
        RETURNS (Iterable[str]): One string per prompt: the first candidate's text if available, otherwise the
            JSON-serialized API response.
        RAISES (ValueError): If a request fails with an HTTP error, or if a prompt was filtered by the API and
            strict mode is enabled.
        """
        headers = {
            "Content-Type": "application/json",
            "Accept": "application/json",
        }
        api_responses: List[str] = []
        prompts = list(prompts)
        url = self._endpoint.format(
            model=self._name, api_key=self._credentials["api_key"]
        )

        def _error_message(r: Any) -> str:
            """Extract the API's error message from a failed response without raising on malformed bodies."""
            try:
                res_content = srsly.json_loads(r.content.decode("utf-8"))
            except ValueError:
                # Body is not valid UTF-8/JSON (e.g. an HTML error page from a proxy).
                res_content = None
            error = res_content.get("error") if isinstance(res_content, dict) else None
            message = error.get("message") if isinstance(error, dict) else None
            # Fall back to the status code only: the request URL contains the API key and must not end up in the
            # message.
            return str(message) if message else f"HTTP status {r.status_code}"

        def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
            """Send a single request and return the decoded JSON response (or an error payload)."""
            # self.retry is inherited from REST; presumably it re-issues the call on transient failures.
            r = self.retry(
                call_method=requests.post,
                url=url,
                headers=headers,
                # Entries of the user-provided config take precedence over the generated payload.
                json={**json_data, **self._config},
                timeout=self._max_request_time,
            )
            try:
                r.raise_for_status()
            except HTTPError as ex:
                # Include the API's specific error message in the exception. This covers all kinds of HTTP errors
                # (e.g. "429: too many requests").
                raise ValueError(
                    f"Request to PaLM API failed: {_error_message(r)}"
                ) from ex
            response = r.json()

            # PaLM returns a 'filters' key when a message was filtered due to safety concerns.
            if "filters" in response:
                if self._strict:
                    raise ValueError(f"API call failed: {response}.")
                assert isinstance(prompts, Sized)
                # NOTE(review): the payload is repeated len(prompts) times although this request covers a single
                # prompt - kept as is for backward compatibility of the returned string, but worth revisiting.
                return {"error": [srsly.json_dumps(response)] * len(prompts)}
            return response

        # PaLM API currently doesn't accept batch prompts, so we're making
        # a request for each iteration. This approach can be prone to rate limit
        # errors. In practice, you can adjust _max_request_time so that the
        # timeout is larger.
        uses_chat = "chat" in self._name
        output_key = "content" if uses_chat else "output"
        for prompt in prompts:
            payload = {"messages": [{"content": prompt}]} if uses_chat else {"text": prompt}
            response = _request({"prompt": payload})
            candidates = response.get("candidates")
            if candidates:
                # Although you can set the number of candidates in PaLM to be greater than 1, we only need to
                # return a single value. In this case, we will just return the very first output.
                api_responses.append(
                    candidates[0].get(output_key, srsly.json_dumps(response))
                )
            else:
                # No (or empty) candidates, e.g. for error payloads: pass the raw response on as JSON.
                api_responses.append(srsly.json_dumps(response))

        return api_responses

    @classmethod
    def get_model_names(cls) -> Tuple[str, ...]:
        """List the model names supported by this class.
        RETURNS (Tuple[str, ...]): Supported model names.
        """
        return "text-bison-001", "chat-bison-001"
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
from typing import Any, Callable, Dict, Iterable | ||
|
||
from confection import SimpleFrozenDict | ||
|
||
from ....compat import Literal | ||
from ....registry import registry | ||
from .model import Endpoints, PaLM | ||
|
||
|
||
@registry.llm_models("spacy.PaLM.v1")
def palm_bison(
    config: Dict[Any, Any] = SimpleFrozenDict(temperature=0),
    name: Literal["chat-bison-001", "text-bison-001"] = "text-bison-001",  # noqa: F821
    strict: bool = PaLM.DEFAULT_STRICT,
    max_tries: int = PaLM.DEFAULT_MAX_TRIES,
    interval: float = PaLM.DEFAULT_INTERVAL,
    max_request_time: float = PaLM.DEFAULT_MAX_REQUEST_TIME,
) -> Callable[[Iterable[str]], Iterable[str]]:
    """Returns PaLM instance for a Bison model using REST to prompt API.
    config (Dict[Any, Any]): LLM config arguments passed on to the initialization of the model instance.
    name (Literal["chat-bison-001", "text-bison-001"]): Model to use.
    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
        or other response object that does not conform to the expectation of how a well-formed response object from
        this API should look like). If False, the API error responses are returned by __call__(), but no error will
        be raised.
    max_tries (int): Max. number of tries for API request.
    interval (float): Time interval (in seconds) for API retries. We implement a base 2 exponential backoff at each
        retry.
    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
    RETURNS (Callable[[Iterable[str]], Iterable[str]]): PaLM instance for the selected Bison model using REST to
        prompt API.
    """
    # The text model is served by the text generation endpoint; any other name goes to the message endpoint.
    if name == "text-bison-001":
        endpoint = Endpoints.TEXT.value
    else:
        endpoint = Endpoints.MSG.value

    return PaLM(
        name=name,
        endpoint=endpoint,
        config=config,
        strict=strict,
        max_tries=max_tries,
        interval=interval,
        max_request_time=max_request_time,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.