Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove model name check for REST models #356

Merged
merged 7 commits into from
Nov 8, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 1 addition & 24 deletions spacy_llm/models/rest/anthropic/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import warnings
from enum import Enum
from typing import Any, Dict, Iterable, List, Sized, Tuple
from typing import Any, Dict, Iterable, List, Sized

import requests # type: ignore[import]
import srsly # type: ignore[import]
Expand Down Expand Up @@ -107,26 +107,3 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:

assert len(api_responses) == len(prompts)
return api_responses

@classmethod
def get_model_names(cls) -> Tuple[str, ...]:
return (
# claude-2
"claude-2",
"claude-2-100k",
# claude-1
"claude-1",
"claude-1-100k",
# claude-instant-1
"claude-instant-1",
"claude-instant-1-100k",
# claude-instant-1.1
"claude-instant-1.1",
"claude-instant-1.1-100k",
# claude-1.3
"claude-1.3",
"claude-1.3-100k",
# others
"claude-1.0",
"claude-1.2",
)
11 changes: 1 addition & 10 deletions spacy_llm/models/rest/azure/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import warnings
from enum import Enum
from typing import Any, Dict, Iterable, List, Sized, Tuple
from typing import Any, Dict, Iterable, List, Sized

import requests # type: ignore[import]
import srsly # type: ignore[import]
Expand Down Expand Up @@ -147,12 +147,3 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
api_responses.append(response.get("text", srsly.json_dumps(response)))

return api_responses

@classmethod
def get_model_names(cls) -> Tuple[str, ...]:
# We treat the deployment name as "model name", hence it can be arbitrary.
return ("",)

def _check_model(self) -> None:
# We treat the deployment name as "model name", hence it can be arbitrary.
pass
17 changes: 1 addition & 16 deletions spacy_llm/models/rest/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import abc
import time
from enum import Enum
from typing import Any, Callable, Dict, Iterable, Optional, Tuple
from typing import Any, Callable, Dict, Iterable, Optional

import requests # type: ignore
from requests import ConnectTimeout, ReadTimeout
Expand Down Expand Up @@ -61,30 +61,15 @@ def __init__(
assert self._interval > 0
assert self._max_request_time > 0

self._check_model()
self._verify_auth()

def _check_model(self) -> None:
"""Checks whether model is supported. Raises if it isn't."""
if self._name not in self.get_model_names():
raise ValueError(
f"Model '{self._name}' is not supported - select one of {self.get_model_names()} instead"
)

@abc.abstractmethod
def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
"""Executes prompts on specified API.
prompts (Iterable[str]): Prompts to execute.
RETURNS (Iterable[str]): API responses.
"""

@classmethod
@abc.abstractmethod
def get_model_names(cls) -> Tuple[str, ...]:
"""Names of supported models.
RETURNS (Tuple[str]): Names of supported models.
"""

@property
@abc.abstractmethod
def credentials(self) -> Dict[str, str]:
Expand Down
6 changes: 1 addition & 5 deletions spacy_llm/models/rest/cohere/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import warnings
from enum import Enum
from typing import Any, Dict, Iterable, List, Sized, Tuple
from typing import Any, Dict, Iterable, List, Sized

import requests # type: ignore[import]
import srsly # type: ignore[import]
Expand Down Expand Up @@ -111,7 +111,3 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
else:
api_responses.append(srsly.json_dumps(response))
return api_responses

@classmethod
def get_model_names(cls) -> Tuple[str, ...]:
return "command", "command-light", "command-light-nightly", "command-nightly"
6 changes: 1 addition & 5 deletions spacy_llm/models/rest/noop/model.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import time
from typing import Dict, Iterable, Tuple
from typing import Dict, Iterable

from ..base import REST

Expand Down Expand Up @@ -33,7 +33,3 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]:
# Assume time penalty for API calls.
time.sleep(NoOpModel._CALL_TIMEOUT)
return [_NOOP_RESPONSE] * len(list(prompts))

@classmethod
def get_model_names(cls) -> Tuple[str, ...]:
return ("NoOp",)
30 changes: 1 addition & 29 deletions spacy_llm/models/rest/openai/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import warnings
from enum import Enum
from typing import Any, Dict, Iterable, List, Sized, Tuple
from typing import Any, Dict, Iterable, List, Sized

import requests # type: ignore[import]
import srsly # type: ignore[import]
Expand Down Expand Up @@ -140,31 +140,3 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
api_responses.append(srsly.json_dumps(response))

return api_responses

@classmethod
def get_model_names(cls) -> Tuple[str, ...]:
return (
# gpt-4
"gpt-4",
"gpt-4-0314",
"gpt-4-32k",
"gpt-4-32k-0314",
# gpt-3.5
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-0613-16k",
"gpt-3.5-turbo-instruct",
# text-davinci
"text-davinci-002",
"text-davinci-003",
# others
"code-davinci-002",
"text-curie-001",
"text-babbage-001",
"text-ada-001",
"davinci",
"curie",
"babbage",
"ada",
)
65 changes: 63 additions & 2 deletions spacy_llm/models/rest/openai/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,35 @@
"""


@registry.llm_models("spacy.GPT-4.v3")
def openai_gpt_4_v3(
    config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
    name: str = "gpt-4",  # noqa: F722
    strict: bool = OpenAI.DEFAULT_STRICT,
    max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
    interval: float = OpenAI.DEFAULT_INTERVAL,
    max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
) -> Callable[[Iterable[str]], Iterable[str]]:
    """Instantiate an OpenAI REST model for the 'gpt-4' model family.

    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
    name (str): Model name to use. Any model name supported by the OpenAI API is
        accepted, e.g. 'gpt-4' or 'gpt-4-1106-preview'.
    RETURNS (Callable[[Iterable[str]], Iterable[str]]): OpenAI instance for 'gpt-4' model.

    DOCS: https://spacy.io/api/large-language-models#models
    """
    # Every gpt-4 variant is a chat model, so the chat endpoint is used unconditionally.
    endpoint = Endpoints.CHAT.value
    return OpenAI(
        name=name,
        endpoint=endpoint,
        config=config,
        strict=strict,
        max_tries=max_tries,
        interval=interval,
        max_request_time=max_request_time,
    )


@registry.llm_models("spacy.GPT-4.v2")
def openai_gpt_4_v2(
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
Expand All @@ -35,7 +64,7 @@ def openai_gpt_4_v2(
"""Returns OpenAI instance for 'gpt-4' model using REST to prompt API.

config (Dict[Any, Any]): LLM config passed on to the model's initialization.
name (Optional[Literal["0314", "32k", "32k-0314"]]): Model to use. Base 'gpt-4' model by default.
name (Literal["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"]): Model to use. Base 'gpt-4' model by default.
RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-4' model

DOCS: https://spacy.io/api/large-language-models#models
Expand Down Expand Up @@ -65,7 +94,8 @@ def openai_gpt_4(
"""Returns OpenAI instance for 'gpt-4' model using REST to prompt API.

config (Dict[Any, Any]): LLM config passed on to the model's initialization.
name (Optional[Literal["0314", "32k", "32k-0314"]]): Model to use. Base 'gpt-4' model by default.
name (Literal["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"]): Model to use. Base 'gpt-4' model by
default.
RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-4' model

DOCS: https://spacy.io/api/large-language-models#models
Expand All @@ -81,6 +111,37 @@ def openai_gpt_4(
)


@registry.llm_models("spacy.GPT-3-5.v3")
def openai_gpt_3_5_v3(
    config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
    name: str = "gpt-3.5-turbo",
    strict: bool = OpenAI.DEFAULT_STRICT,
    max_tries: int = OpenAI.DEFAULT_MAX_TRIES,
    interval: float = OpenAI.DEFAULT_INTERVAL,
    max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME,
) -> Callable[[Iterable[str]], Iterable[str]]:
    """Instantiate an OpenAI REST model for the 'gpt-3.5' model family.

    config (Dict[Any, Any]): LLM config passed on to the model's initialization.
    name (str): Name of model to use. Any model name supported by the OpenAI API
        is accepted, e.g. 'gpt-3.5-turbo' or 'gpt-3.5-turbo-16k'.
    RETURNS (Callable[[Iterable[str]], Iterable[str]]): OpenAI instance for 'gpt-3.5' model.

    DOCS: https://spacy.io/api/large-language-models#models
    """
    # gpt-3.5-turbo-instruct runs on the non-chat endpoint, so we use that one by
    # default for it to allow batching; all other gpt-3.5 variants are chat models.
    if name == "gpt-3.5-turbo-instruct":
        endpoint = Endpoints.NON_CHAT.value
    else:
        endpoint = Endpoints.CHAT.value
    return OpenAI(
        name=name,
        endpoint=endpoint,
        config=config,
        strict=strict,
        max_tries=max_tries,
        interval=interval,
        max_request_time=max_request_time,
    )


@registry.llm_models("spacy.GPT-3-5.v2")
def openai_gpt_3_5_v2(
config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE),
Expand Down
6 changes: 1 addition & 5 deletions spacy_llm/models/rest/palm/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import os
import warnings
from enum import Enum
from typing import Any, Dict, Iterable, List, Sized, Tuple
from typing import Any, Dict, Iterable, List, Sized

import requests # type: ignore[import]
import srsly # type: ignore[import]
Expand Down Expand Up @@ -107,7 +107,3 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
api_responses.append(srsly.json_dumps(response))

return api_responses

@classmethod
def get_model_names(cls) -> Tuple[str, ...]:
return "text-bison-001", "chat-bison-001"
2 changes: 1 addition & 1 deletion spacy_llm/tests/test_combinations.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
@pytest.mark.skipif(has_langchain is False, reason="LangChain is not installed")
@pytest.mark.parametrize(
"model",
["langchain.OpenAI.v1", "spacy.GPT-3-5.v1", "spacy.GPT-3-5.v2"],
["langchain.OpenAI.v1", "spacy.GPT-3-5.v3", "spacy.GPT-4.v3"],
ids=["langchain", "rest-openai", "rest-openai"],
)
@pytest.mark.parametrize(
Expand Down