-
-
Notifications
You must be signed in to change notification settings - Fork 92
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
303 additions
and
23 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .model import AzureOpenAI | ||
from .registry import azure_openai | ||
|
||
__all__ = ["AzureOpenAI", "azure_openai"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import os | ||
import warnings | ||
from enum import Enum | ||
from typing import Any, Dict, Iterable, List, Sized, Tuple | ||
|
||
import requests # type: ignore[import] | ||
import srsly # type: ignore[import] | ||
from requests import HTTPError | ||
|
||
from ..base import REST | ||
|
||
|
||
class ModelType(str, Enum):
    """Kind of model served by an Azure OpenAI deployment.

    Members subclass `str`, so each member compares equal to its raw string value and can be
    looked up from that value, e.g. `ModelType("chat")`.
    """

    # Text completion model.
    COMPLETION = "completions"
    # Chat model.
    CHAT = "chat"
|
||
|
||
class AzureOpenAI(REST):
    """REST client for a model deployed through the Azure OpenAI service.

    The deployment name is used as the "model name": Azure lets users pick arbitrary deployment
    names, so no validation against a list of known model names is performed.
    """

    def __init__(
        self,
        name: str,
        endpoint: str,
        config: Dict[Any, Any],
        strict: bool,
        max_tries: int,
        interval: float,
        max_request_time: float,
        model_type: ModelType,
        api_version: str = "2023-05-15",
    ):
        """Initializes the Azure OpenAI client.
        name (str): Name of the deployment to query.
        endpoint (str): Base URL of the Azure OpenAI resource.
        config (Dict[Any, Any]): Extra parameters merged into the JSON body of every request.
        strict (bool): If True, an "error" entry in an API response raises a ValueError. If False, the serialized
            error response is returned in place of the model output.
        max_tries (int): Forwarded to the REST base class.
        interval (float): Forwarded to the REST base class.
        max_request_time (float): Timeout (in seconds) passed to each HTTP request.
        model_type (ModelType): Whether the deployment serves a completion or a chat model. Raw string values
            ("completions", "chat") are accepted as well.
        api_version (str): Value sent as the "api-version" query parameter.
        """
        # Coerce so that plain strings passed from Python code work too; `.endpoint` and `__call__` compare against
        # ModelType members.
        self._model_type = ModelType(model_type)
        self._api_version = api_version
        # NOTE(review): both attributes are deliberately set before delegating, as in the original ordering -
        # presumably the base initializer may already use `endpoint` or `_verify_auth()`; confirm against REST.
        super().__init__(
            name=name,
            endpoint=endpoint,
            config=config,
            strict=strict,
            max_tries=max_tries,
            interval=interval,
            max_request_time=max_request_time,
        )

    @property
    def endpoint(self) -> str:
        """Returns fully formed endpoint URL.
        RETURNS (str): Fully formed endpoint URL.
        """
        base_url = self._endpoint if self._endpoint.endswith("/") else self._endpoint + "/"
        # Chat deployments are served under ".../chat/completions", completion deployments under ".../completions".
        route = (
            "chat/completions"
            if self._model_type == ModelType.CHAT
            else "completions"
        )
        return f"{base_url}openai/deployments/{self._name}/{route}"

    @property
    def credentials(self) -> Dict[str, str]:
        """Reads the API key from the environment variable 'AZURE_OPENAI_KEY'.
        RETURNS (Dict[str, str]): Authentication header(s) to send with each request.
        """
        api_key = os.getenv("AZURE_OPENAI_KEY")
        if api_key is None:
            warnings.warn(
                "Could not find the API key to access the Azure OpenAI API. Ensure you have an API key "
                "set up (see "
                "https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart?pivots=rest-api&tabs=bash#set-up"
                "), then make it available as an environment variable 'AZURE_OPENAI_KEY'."
            )

        # A missing key is fatal: fail right after the warning instead of sending unauthenticated requests.
        assert api_key is not None
        return {"api-key": api_key}

    def _verify_auth(self) -> None:
        """Verifies access by sending a single test prompt. Errors raised by __call__() propagate unchanged."""
        # TODO: inspect HTTP status codes (e.g. wrong token) to emit more specific warnings.
        self(["test"])

    def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
        """Sends each prompt to the deployment and collects the generated texts.
        prompts (Iterable[str]): Prompts to send; one request is issued per prompt.
        RETURNS (Iterable[str]): One response per prompt. If the API reports an error and `strict` is False, one
            serialized error response per prompt is returned instead.
        """
        headers = {
            **self._credentials,
            "Content-Type": "application/json",
        }
        api_responses: List[str] = []
        prompts = list(prompts)
        is_chat = self._model_type == ModelType.CHAT

        def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
            r = self.retry(
                call_method=requests.post,
                url=self.endpoint,
                headers=headers,
                json={**json_data, **self._config},
                timeout=self._max_request_time,
                params={"api-version": self._api_version},
            )
            try:
                r.raise_for_status()
            except HTTPError as ex:
                res_content = srsly.json_loads(r.content.decode("utf-8"))
                # Include specific error message in exception.
                raise ValueError(
                    f"Request to Azure OpenAI API failed: "
                    f"{res_content.get('error', {}).get('message', str(res_content))}"
                ) from ex
            responses = r.json()

            if "error" in responses:
                if self._strict:
                    raise ValueError(f"API call failed: {responses}.")
                # Non-strict mode: hand back one serialized error per prompt.
                return {"error": [srsly.json_dumps(responses)] * len(prompts)}

            return responses

        # The (Azure) OpenAI API doesn't support batching yet, so we have to send individual requests.
        # https://learn.microsoft.com/en-us/answers/questions/1334800/batching-requests-in-azure-openai
        for prompt in prompts:
            if is_chat:
                responses = _request(
                    {"messages": [{"role": "user", "content": prompt}]}
                )
            else:
                responses = _request({"prompt": prompt})
            if "error" in responses:
                return responses["error"]

            # Process responses.
            assert len(responses["choices"]) == 1
            choice = responses["choices"][0]
            # Fall back to the serialized choice if the expected field is missing.
            fallback = srsly.json_dumps(choice)
            if is_chat:
                # Chat responses carry the generated text in choice["message"]["content"].
                text = choice.get("message", {}).get("content", fallback)
            else:
                # Completion responses carry the generated text in choice["text"].
                text = choice.get("text", fallback)
            api_responses.append(text)

        return api_responses

    @classmethod
    def get_model_names(cls) -> Tuple[str, ...]:
        """Returns the names of supported models.
        RETURNS (Tuple[str, ...]): A single empty string, as any deployment name is acceptable.
        """
        # We treat the deployment name as "model name", hence it can be arbitrary.
        return ("",)

    def _check_model(self) -> None:
        """Deliberately does nothing, as there is no fixed list of names to check against."""
        # We treat the deployment name as "model name", hence it can be arbitrary.
        pass
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from typing import Any, Callable, Dict, Iterable | ||
|
||
from confection import SimpleFrozenDict | ||
|
||
from ....registry import registry | ||
from .model import AzureOpenAI, ModelType | ||
|
||
_DEFAULT_TEMPERATURE = 0.0 | ||
|
||
|
||
@registry.llm_models("spacy.Azure.v1")
def azure_openai(
    name: str,
    base_url: str,
    model_type: ModelType,
    config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
    strict: bool = AzureOpenAI.DEFAULT_STRICT,
    max_tries: int = AzureOpenAI.DEFAULT_MAX_TRIES,
    interval: float = AzureOpenAI.DEFAULT_INTERVAL,
    max_request_time: float = AzureOpenAI.DEFAULT_MAX_REQUEST_TIME,
    api_version: str = "2023-05-15",
) -> Callable[[Iterable[str]], Iterable[str]]:
    """Returns AzureOpenAI instance for the given deployment, using REST to prompt the API.
    name (str): Name of the deployment to use. Note that this does not necessarily equal the name of the model used by
        that deployment, as deployment names in Azure OpenAI can be arbitrary.
    base_url (str): The URL for your Azure OpenAI endpoint. This is usually something like
        "https://{prefix}.openai.azure.com/".
    model_type (ModelType): Whether the deployed model is a text completion model (e. g.
        text-davinci-003) or a chat model (e. g. gpt-4).
    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
        or other response object that does not conform to the expectation of how a well-formed response object from
        this API should look like). If False, the API error responses are returned by __call__(), but no error will
        be raised.
    max_tries (int): Max. number of tries for API request.
    interval (float): Time interval (in seconds) for API retries. We implement a base 2 exponential backoff at each
        retry.
    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
    api_version (str): API version to use.
    RETURNS (Callable[[Iterable[str]], Iterable[str]]): AzureOpenAI instance for the given deployment.
    DOCS: https://spacy.io/api/large-language-models#models
    """
    # The registry argument is called `base_url`; the model class takes it as `endpoint`.
    return AzureOpenAI(
        name=name,
        endpoint=base_url,
        config=config,
        strict=strict,
        max_tries=max_tries,
        interval=interval,
        max_request_time=max_request_time,
        api_version=api_version,
        model_type=model_type,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,6 @@ | ||
import os | ||
|
||
# Flags recording whether each provider's API key environment variable is set.
has_openai_key = os.getenv("OPENAI_API_KEY") is not None
has_anthropic_key = os.getenv("ANTHROPIC_API_KEY") is not None
has_cohere_key = os.getenv("CO_API_KEY") is not None
has_azure_openai_key = os.getenv("AZURE_OPENAI_KEY") is not None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters