Skip to content

Commit

Permalink
Add support for Azure OpenAI.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Oct 5, 2023
1 parent 63cdee9 commit f1bcc2d
Show file tree
Hide file tree
Showing 8 changed files with 303 additions and 23 deletions.
22 changes: 15 additions & 7 deletions spacy_llm/models/langchain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from ...registry import registry

try:
from langchain import base_language # noqa: F401
from langchain import llms # noqa: F401
except (ImportError, AttributeError):
llms = None
Expand All @@ -18,7 +19,7 @@ def __init__(
api: str,
config: Dict[Any, Any],
query: Callable[
["langchain.llms.base.BaseLLM", Iterable[Any]],
["langchain.base_language.BaseLanguageModel", Iterable[Any]],
Iterable[Any],
],
):
Expand All @@ -36,9 +37,11 @@ def __init__(
self._check_installation()

@staticmethod
def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.base.BaseLLM"]]:
def get_type_to_cls_dict() -> Dict[
str, Type["langchain.base_language.BaseLanguageModel"]
]:
"""Returns langchain.llms.type_to_cls_dict.
RETURNS (Dict[str, Type[langchain.llms.base.BaseLLM]]): langchain.llms.type_to_cls_dict.
RETURNS (Dict[str, Type[langchain.base_language.BaseLanguageModel]]): langchain.llms.type_to_cls_dict.
"""
return langchain.llms.type_to_cls_dict

Expand All @@ -51,10 +54,10 @@ def __call__(self, prompts: Iterable[Any]) -> Iterable[Any]:

@staticmethod
def query_langchain(
model: "langchain.llms.base.BaseLLM", prompts: Iterable[Any]
model: "langchain.base_language.BaseLanguageModel", prompts: Iterable[Any]
) -> Iterable[Any]:
"""Query LangChain model naively.
model (langchain.llms.base.BaseLLM): LangChain model.
model (langchain.base_language.BaseLanguageModel): LangChain model.
prompts (Iterable[Any]): Prompts to execute.
RETURNS (Iterable[Any]): LLM responses.
"""
Expand All @@ -74,7 +77,10 @@ def _langchain_model_maker(class_id: str):
def langchain_model(
name: str,
query: Optional[
Callable[["langchain.llms.base.BaseLLM", Iterable[str]], Iterable[str]]
Callable[
["langchain.base_language.BaseLanguageModel", Iterable[str]],
Iterable[str],
]
] = None,
config: Dict[Any, Any] = SimpleFrozenDict(),
langchain_class_id: str = class_id,
Expand Down Expand Up @@ -117,7 +123,9 @@ def register_models() -> None:

@registry.llm_queries("spacy.CallLangChain.v1")
def query_langchain() -> (
Callable[["langchain.llms.base.BaseLLM", Iterable[Any]], Iterable[Any]]
Callable[
["langchain.base_language.BaseLanguageModel", Iterable[Any]], Iterable[Any]
]
):
"""Returns query Callable for LangChain.
RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]]:): Callable executing simple prompts on
Expand Down
3 changes: 2 additions & 1 deletion spacy_llm/models/rest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from . import anthropic, base, cohere, noop, openai
from . import anthropic, azure, base, cohere, noop, openai

__all__ = [
"anthropic",
"azure",
"base",
"cohere",
"openai",
Expand Down
4 changes: 4 additions & 0 deletions spacy_llm/models/rest/azure/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .model import AzureOpenAI
from .registry import azure_openai

__all__ = ["AzureOpenAI", "azure_openai"]
179 changes: 179 additions & 0 deletions spacy_llm/models/rest/azure/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
import os
import warnings
from enum import Enum
from typing import Any, Dict, Iterable, List, Sized, Tuple

import requests # type: ignore[import]
import srsly # type: ignore[import]
from requests import HTTPError

from ..base import REST


class ModelType(str, Enum):
    """Kind of model served by an Azure OpenAI deployment.

    The member values are plain strings (the enum also derives from `str`), so a member
    compares equal to its value, e.g. `ModelType.CHAT == "chat"`.
    """

    # Text completion models.
    COMPLETION = "completions"
    # Chat models.
    CHAT = "chat"


class AzureOpenAI(REST):
    """REST client for models deployed with the Azure OpenAI service.

    The deployment name is used in place of a model name, which is why no model name
    validation takes place (see `get_model_names()` and `_check_model()`).
    """

    def __init__(
        self,
        name: str,
        endpoint: str,
        config: Dict[Any, Any],
        strict: bool,
        max_tries: int,
        interval: float,
        max_request_time: float,
        model_type: ModelType,
        api_version: str = "2023-05-15",
    ):
        """Initializes model instance for Azure OpenAI deployments.
        name (str): Name of the deployment to use.
        endpoint (str): Base URL of the Azure OpenAI resource. The deployment-specific path is appended by the
            `endpoint` property.
        config (Dict[Any, Any]): Config merged into the JSON payload of every request.
        strict (bool): If True, ValueError is raised if the API returns an error response. If False, the error
            responses are returned by __call__() instead.
        max_tries (int): Max. number of tries for API request.
        interval (float): Time interval (in seconds) for API retries.
        max_request_time (float): Max. time (in seconds) to wait for request to terminate.
        model_type (ModelType): Whether the deployed model is a completion or a chat model. Plain strings
            ("completions", "chat") are accepted as well and converted to the enum.
        api_version (str): API version to use.
        """
        # Set before calling the parent initializer: `endpoint` and `__call__()` read both attributes.
        # NOTE(review): presumably REST.__init__() already triggers a request via _verify_auth() - confirm.
        # Coercing to the enum keeps `.value` working if the value arrives as a plain string.
        self._model_type = ModelType(model_type)
        self._api_version = api_version
        super().__init__(
            name=name,
            endpoint=endpoint,
            config=config,
            strict=strict,
            max_tries=max_tries,
            interval=interval,
            max_request_time=max_request_time,
        )

    @property
    def endpoint(self) -> str:
        """Returns fully formed endpoint URL.
        RETURNS (str): Fully formed endpoint URL.
        """
        base_url = self._endpoint if self._endpoint.endswith("/") else self._endpoint + "/"
        # Chat deployments are served under ".../chat/completions", not ".../chat". The enum value stays "chat" so
        # that existing configs remain valid.
        operation = (
            "chat/completions"
            if self._model_type == ModelType.CHAT
            else self._model_type.value
        )
        return f"{base_url}openai/deployments/{self._name}/{operation}"

    @property
    def credentials(self) -> Dict[str, str]:
        """Reads the API key from the environment variable 'AZURE_OPENAI_KEY'.
        RETURNS (Dict[str, str]): Authentication header ("api-key") to send with each request.
        """
        # Fetch and check the key
        api_key = os.getenv("AZURE_OPENAI_KEY")
        if api_key is None:
            warnings.warn(
                "Could not find the API key to access the Azure OpenAI API. Ensure you have an API key "
                "set up (see "
                "https://learn.microsoft.com/en-us/azure/ai-services/openai/quickstart?pivots=rest-api&tabs=bash#set-up"
                ", then make it available as an environment variable 'AZURE_OPENAI_KEY'."
            )

        # Without a key no request can be authenticated, so stop here after having warned the user.
        assert api_key is not None
        return {"api-key": api_key}

    def _verify_auth(self) -> None:
        """Verifies access by sending a minimal test prompt.
        Any ValueError raised by the request (e. g. due to rejected credentials) is propagated to the caller.
        """
        # TODO: map specific HTTP status codes (e. g. wrong token) to more helpful warnings.
        self(["test"])

    def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
        """Sends each prompt to the deployment and collects the generated texts.
        prompts (Iterable[str]): Prompts to execute.
        RETURNS (Iterable[str]): One response per prompt. If `strict` is False and the API reports an error, a list
            with one copy of the serialized error response per prompt is returned instead.
        """
        headers = {
            **self._credentials,
            "Content-Type": "application/json",
        }
        prompts = list(prompts)
        is_chat = self._model_type == ModelType.CHAT

        def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
            r = self.retry(
                call_method=requests.post,
                url=self.endpoint,
                headers=headers,
                # Entries in the user-provided config take precedence over the prompt payload.
                json={**json_data, **self._config},
                timeout=self._max_request_time,
                params={"api-version": self._api_version},
            )
            try:
                r.raise_for_status()
            except HTTPError as ex:
                # Include specific error message in exception. Error bodies aren't guaranteed to be JSON (or to
                # have the expected structure), so fall back to the raw text instead of masking the HTTP error.
                try:
                    res_content = srsly.json_loads(r.content.decode("utf-8"))
                    message = res_content.get("error", {}).get(
                        "message", str(res_content)
                    )
                except (ValueError, AttributeError):
                    message = r.text
                raise ValueError(
                    f"Request to Azure OpenAI API failed: {message}"
                ) from ex
            responses = r.json()

            if "error" in responses:
                if self._strict:
                    raise ValueError(f"API call failed: {responses}.")
                return {"error": [srsly.json_dumps(responses)] * len(prompts)}

            return responses

        # The (Azure) OpenAI API doesn't support batching yet, so we have to send individual requests.
        # https://learn.microsoft.com/en-us/answers/questions/1334800/batching-requests-in-azure-openai
        # Note: the chat path is yet (2023-10-05) untested, as Azure doesn't seem to allow the deployment of a chat
        # model yet.
        api_responses: List[str] = []
        for prompt in prompts:
            if is_chat:
                payload: Dict[str, Any] = {
                    "messages": [{"role": "user", "content": prompt}]
                }
            else:
                payload = {"prompt": prompt}

            responses = _request(payload)
            if "error" in responses:
                return responses["error"]

            # Process responses.
            assert len(responses["choices"]) == 1
            choice = responses["choices"][0]
            if is_chat:
                # Chat models return the generated text in choice["message"]["content"].
                text = choice.get("message", {}).get(
                    "content", srsly.json_dumps(choice)
                )
            else:
                # Completion models return the generated text in choice["text"].
                text = choice.get("text", srsly.json_dumps(choice))
            api_responses.append(text)

        return api_responses

    @classmethod
    def get_model_names(cls) -> Tuple[str, ...]:
        """Returns names of supported models.
        RETURNS (Tuple[str, ...]): Placeholder tuple, as deployment names are arbitrary.
        """
        # We treat the deployment name as "model name", hence it can be arbitrary.
        return ("",)

    def _check_model(self) -> None:
        """Skips model name validation."""
        # We treat the deployment name as "model name", hence it can be arbitrary.
        pass
55 changes: 55 additions & 0 deletions spacy_llm/models/rest/azure/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from typing import Any, Callable, Dict, Iterable

from confection import SimpleFrozenDict

from ....registry import registry
from .model import AzureOpenAI, ModelType

_DEFAULT_TEMPERATURE = 0.0


@registry.llm_models("spacy.Azure.v1")
def azure_openai(
    name: str,
    base_url: str,
    model_type: ModelType,
    config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
    strict: bool = AzureOpenAI.DEFAULT_STRICT,
    max_tries: int = AzureOpenAI.DEFAULT_MAX_TRIES,
    interval: float = AzureOpenAI.DEFAULT_INTERVAL,
    max_request_time: float = AzureOpenAI.DEFAULT_MAX_REQUEST_TIME,
    api_version: str = "2023-05-15",
) -> Callable[[Iterable[str]], Iterable[str]]:
    """Returns Azure OpenAI instance for the specified deployment using REST to prompt API.
    name (str): Name of the deployment to use. Note that this does not necessarily equal the name of the model used by
        that deployment, as deployment names in Azure OpenAI can be arbitrary.
    base_url (str): The URL for your Azure OpenAI endpoint. This is usually something like
        "https://{prefix}.openai.azure.com/".
    model_type (ModelType): Whether the deployed model is a text completion model (e. g.
        text-davinci-003) or a chat model (e. g. gpt-4).
    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
    strict (bool): If True, ValueError is raised if the LLM API returns a malformed response (i. e. any kind of JSON
        or other response object that does not conform to the expectation of how a well-formed response object from
        this API should look like). If False, the API error responses are returned by __call__(), but no error will
        be raised.
    max_tries (int): Max. number of tries for API request.
    interval (float): Time interval (in seconds) for API retries. We implement a base 2 exponential backoff at each
        retry.
    max_request_time (float): Max. time (in seconds) to wait for request to terminate before raising an exception.
    api_version (str): API version to use.
    RETURNS (Callable[[Iterable[str]], Iterable[str]]): Azure OpenAI instance for the specified deployment.
    DOCS: https://spacy.io/api/large-language-models#models
    """
    return AzureOpenAI(
        name=name,
        endpoint=base_url,
        config=config,
        strict=strict,
        max_tries=max_tries,
        interval=interval,
        max_request_time=max_request_time,
        api_version=api_version,
        model_type=model_type,
    )
18 changes: 4 additions & 14 deletions spacy_llm/tests/compat.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
import os

if os.getenv("OPENAI_API_KEY") is None:
has_openai_key = False
else:
has_openai_key = True

if os.getenv("ANTHROPIC_API_KEY") is None:
has_anthropic_key = False
else:
has_anthropic_key = True

if os.getenv("CO_API_KEY") is None:
has_cohere_key = False
else:
has_cohere_key = True
has_openai_key = os.getenv("OPENAI_API_KEY") is not None
has_anthropic_key = os.getenv("ANTHROPIC_API_KEY") is not None
has_cohere_key = os.getenv("CO_API_KEY") is not None
has_azure_openai_key = os.getenv("AZURE_OPENAI_KEY") is not None
19 changes: 19 additions & 0 deletions spacy_llm/tests/models/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,22 @@ def test_initialization():
nlp = spacy.blank("en")
nlp.add_pipe("llm", config=PIPE_CFG)
nlp("This is a test.")


@pytest.mark.external
@pytest.mark.skipif(has_langchain is False, reason="LangChain is not installed")
def test_initialization_azure_openai():
    """Test initialization and simple run with Azure OpenAI models."""
    # Model section: LangChain-backed Azure model, queried through the default LangChain query function.
    model_cfg = {
        "@llm_models": "langchain.Azure.v1",
        "name": "ada",
        "config": {"temperature": 0.3},
        "query": {"@llm_queries": "spacy.CallLangChain.v1"},
    }
    # The no-op task is sufficient here, as only initialization and a single run are checked.
    task_cfg = {"@llm_tasks": "spacy.NoOp.v1"}

    nlp = spacy.blank("en")
    nlp.add_pipe("llm", config={"model": model_cfg, "task": task_cfg})
    nlp("This is a test.")
26 changes: 25 additions & 1 deletion spacy_llm/tests/models/test_rest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from spacy.tokens import Doc

from ...registry import registry
from ..compat import has_openai_key
from ..compat import has_azure_openai_key, has_openai_key

PIPE_CFG = {
"model": {
Expand Down Expand Up @@ -105,3 +105,27 @@ def test_max_time_error_handling():
},
},
)


@pytest.mark.skipif(
    not has_azure_openai_key, reason="Azure OpenAI API key not available"
)
@pytest.mark.external
@pytest.mark.parametrize("deployment_name", ("gpt-35-turbo", "gpt-35-turbo-instruct"))
def test_azure_openai(deployment_name: str):
    """Test initialization and simple run for Azure OpenAI.
    deployment_name (str): Name of the Azure OpenAI deployment to query.
    """
    nlp = spacy.blank("en")
    # The config is built fresh for every parametrized run, so there is no shared state to protect with a copy.
    pipe_cfg = {
        "model": {
            "@llm_models": "spacy.Azure.v1",
            "base_url": "https://explosion.openai.azure.com/",
            "model_type": "completions",
            "name": deployment_name,
        },
        "task": {"@llm_tasks": "spacy.NoOp.v1"},
        "save_io": True,
    }

    nlp.add_pipe("llm", config=pipe_cfg)
    nlp("This is a test.")

0 comments on commit f1bcc2d

Please sign in to comment.