diff --git a/spacy_llm/models/rest/anthropic/model.py b/spacy_llm/models/rest/anthropic/model.py index 602ba14b..774d0e83 100644 --- a/spacy_llm/models/rest/anthropic/model.py +++ b/spacy_llm/models/rest/anthropic/model.py @@ -1,7 +1,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Sized, Tuple +from typing import Any, Dict, Iterable, List, Sized import requests # type: ignore[import] import srsly # type: ignore[import] @@ -107,26 +107,3 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: assert len(api_responses) == len(prompts) return api_responses - - @classmethod - def get_model_names(cls) -> Tuple[str, ...]: - return ( - # claude-2 - "claude-2", - "claude-2-100k", - # claude-1 - "claude-1", - "claude-1-100k", - # claude-instant-1 - "claude-instant-1", - "claude-instant-1-100k", - # claude-instant-1.1 - "claude-instant-1.1", - "claude-instant-1.1-100k", - # claude-1.3 - "claude-1.3", - "claude-1.3-100k", - # others - "claude-1.0", - "claude-1.2", - ) diff --git a/spacy_llm/models/rest/azure/model.py b/spacy_llm/models/rest/azure/model.py index a5000d9b..d8f433d6 100644 --- a/spacy_llm/models/rest/azure/model.py +++ b/spacy_llm/models/rest/azure/model.py @@ -1,7 +1,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Sized, Tuple +from typing import Any, Dict, Iterable, List, Sized import requests # type: ignore[import] import srsly # type: ignore[import] @@ -147,12 +147,3 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: api_responses.append(response.get("text", srsly.json_dumps(response))) return api_responses - - @classmethod - def get_model_names(cls) -> Tuple[str, ...]: - # We treat the deployment name as "model name", hence it can be arbitrary. - return ("",) - - def _check_model(self) -> None: - # We treat the deployment name as "model name", hence it can be arbitrary. 
- pass diff --git a/spacy_llm/models/rest/base.py b/spacy_llm/models/rest/base.py index b7dccca3..f54f90ac 100644 --- a/spacy_llm/models/rest/base.py +++ b/spacy_llm/models/rest/base.py @@ -1,7 +1,7 @@ import abc import time from enum import Enum -from typing import Any, Callable, Dict, Iterable, Optional, Tuple +from typing import Any, Callable, Dict, Iterable, Optional import requests # type: ignore from requests import ConnectTimeout, ReadTimeout @@ -61,16 +61,8 @@ def __init__( assert self._interval > 0 assert self._max_request_time > 0 - self._check_model() self._verify_auth() - def _check_model(self) -> None: - """Checks whether model is supported. Raises if it isn't.""" - if self._name not in self.get_model_names(): - raise ValueError( - f"Model '{self._name}' is not supported - select one of {self.get_model_names()} instead" - ) - @abc.abstractmethod def __call__(self, prompts: Iterable[str]) -> Iterable[str]: """Executes prompts on specified API. @@ -78,13 +70,6 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]: RETURNS (Iterable[str]): API responses. """ - @classmethod - @abc.abstractmethod - def get_model_names(cls) -> Tuple[str, ...]: - """Names of supported models. - RETURNS (Tuple[str]): Names of supported models. 
- """ - @property @abc.abstractmethod def credentials(self) -> Dict[str, str]: diff --git a/spacy_llm/models/rest/cohere/model.py b/spacy_llm/models/rest/cohere/model.py index 293ed92b..ecd60d5d 100644 --- a/spacy_llm/models/rest/cohere/model.py +++ b/spacy_llm/models/rest/cohere/model.py @@ -1,7 +1,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Sized, Tuple +from typing import Any, Dict, Iterable, List, Sized import requests # type: ignore[import] import srsly # type: ignore[import] @@ -111,7 +111,3 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: else: api_responses.append(srsly.json_dumps(response)) return api_responses - - @classmethod - def get_model_names(cls) -> Tuple[str, ...]: - return "command", "command-light", "command-light-nightly", "command-nightly" diff --git a/spacy_llm/models/rest/noop/model.py b/spacy_llm/models/rest/noop/model.py index 0e3e0398..31d830b8 100644 --- a/spacy_llm/models/rest/noop/model.py +++ b/spacy_llm/models/rest/noop/model.py @@ -1,5 +1,5 @@ import time -from typing import Dict, Iterable, Tuple +from typing import Dict, Iterable from ..base import REST @@ -33,7 +33,3 @@ def __call__(self, prompts: Iterable[str]) -> Iterable[str]: # Assume time penalty for API calls. 
time.sleep(NoOpModel._CALL_TIMEOUT) return [_NOOP_RESPONSE] * len(list(prompts)) - - @classmethod - def get_model_names(cls) -> Tuple[str, ...]: - return ("NoOp",) diff --git a/spacy_llm/models/rest/openai/model.py b/spacy_llm/models/rest/openai/model.py index a712e082..8fa9dc20 100644 --- a/spacy_llm/models/rest/openai/model.py +++ b/spacy_llm/models/rest/openai/model.py @@ -1,7 +1,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Sized, Tuple +from typing import Any, Dict, Iterable, List, Sized import requests # type: ignore[import] import srsly # type: ignore[import] @@ -140,31 +140,3 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: api_responses.append(srsly.json_dumps(response)) return api_responses - - @classmethod - def get_model_names(cls) -> Tuple[str, ...]: - return ( - # gpt-4 - "gpt-4", - "gpt-4-0314", - "gpt-4-32k", - "gpt-4-32k-0314", - # gpt-3.5 - "gpt-3.5-turbo", - "gpt-3.5-turbo-16k", - "gpt-3.5-turbo-0613", - "gpt-3.5-turbo-0613-16k", - "gpt-3.5-turbo-instruct", - # text-davinci - "text-davinci-002", - "text-davinci-003", - # others - "code-davinci-002", - "text-curie-001", - "text-babbage-001", - "text-ada-001", - "davinci", - "curie", - "babbage", - "ada", - ) diff --git a/spacy_llm/models/rest/openai/registry.py b/spacy_llm/models/rest/openai/registry.py index 1a4c4fd7..82436e4a 100644 --- a/spacy_llm/models/rest/openai/registry.py +++ b/spacy_llm/models/rest/openai/registry.py @@ -21,6 +21,35 @@ """ +@registry.llm_models("spacy.GPT-4.v3") +def openai_gpt_4_v3( + config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE), + name: str = "gpt-4", # noqa: F722 + strict: bool = OpenAI.DEFAULT_STRICT, + max_tries: int = OpenAI.DEFAULT_MAX_TRIES, + interval: float = OpenAI.DEFAULT_INTERVAL, + max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, +) -> Callable[[Iterable[str]], Iterable[str]]: + """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. 
+ + config (Dict[Any, Any]): LLM config passed on to the model's initialization. + name (str): Model name to use. Can be any model name supported by the OpenAI API - e. g. 'gpt-4', + "gpt-4-1106-preview", .... + RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-4' model + + DOCS: https://spacy.io/api/large-language-models#models + """ + return OpenAI( + name=name, + endpoint=Endpoints.CHAT.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + ) + + @registry.llm_models("spacy.GPT-4.v2") def openai_gpt_4_v2( config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE), @@ -35,7 +64,7 @@ def openai_gpt_4_v2( """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. - name (Optional[Literal["0314", "32k", "32k-0314"]]): Model to use. Base 'gpt-4' model by default. + name (Literal["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"]): Model to use. Base 'gpt-4' model by default. RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-4' model DOCS: https://spacy.io/api/large-language-models#models @@ -65,7 +94,8 @@ def openai_gpt_4( """Returns OpenAI instance for 'gpt-4' model using REST to prompt API. config (Dict[Any, Any]): LLM config passed on to the model's initialization. - name (Optional[Literal["0314", "32k", "32k-0314"]]): Model to use. Base 'gpt-4' model by default. + name (Literal["gpt-4", "gpt-4-0314", "gpt-4-32k", "gpt-4-32k-0314"]): Model to use. Base 'gpt-4' model by + default. 
RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-4' model DOCS: https://spacy.io/api/large-language-models#models @@ -81,6 +111,37 @@ def openai_gpt_4( ) +@registry.llm_models("spacy.GPT-3-5.v3") +def openai_gpt_3_5_v3( + config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE), + name: str = "gpt-3.5-turbo", + strict: bool = OpenAI.DEFAULT_STRICT, + max_tries: int = OpenAI.DEFAULT_MAX_TRIES, + interval: float = OpenAI.DEFAULT_INTERVAL, + max_request_time: float = OpenAI.DEFAULT_MAX_REQUEST_TIME, +) -> Callable[[Iterable[str]], Iterable[str]]: + """Returns OpenAI instance for 'gpt-3.5' model using REST to prompt API. + + config (Dict[Any, Any]): LLM config passed on to the model's initialization. + name (str): Name of model to use. Can be any model name supported by the OpenAI API - e. g. "gpt-3.5-turbo", + "gpt-3.5-turbo-instruct", .... + RETURNS (Callable[[Iterable[str]], Iterable[str]]]): OpenAI instance for 'gpt-3.5' model + + DOCS: https://spacy.io/api/large-language-models#models + """ + return OpenAI( + name=name, + endpoint=Endpoints.CHAT.value + # gpt-3.5-turbo-instruct runs on the non-chat endpoint (allows batching); all other models use chat.
+ if name != "gpt-3.5-turbo-instruct" else Endpoints.NON_CHAT.value, + config=config, + strict=strict, + max_tries=max_tries, + interval=interval, + max_request_time=max_request_time, + ) + + @registry.llm_models("spacy.GPT-3-5.v2") def openai_gpt_3_5_v2( config: Dict[Any, Any] = SimpleFrozenDict(temperature=_DEFAULT_TEMPERATURE), diff --git a/spacy_llm/models/rest/palm/model.py b/spacy_llm/models/rest/palm/model.py index 1e9b10b1..1a488000 100644 --- a/spacy_llm/models/rest/palm/model.py +++ b/spacy_llm/models/rest/palm/model.py @@ -1,7 +1,7 @@ import os import warnings from enum import Enum -from typing import Any, Dict, Iterable, List, Sized, Tuple +from typing import Any, Dict, Iterable, List, Sized import requests # type: ignore[import] import srsly # type: ignore[import] @@ -107,7 +107,3 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: api_responses.append(srsly.json_dumps(response)) return api_responses - - @classmethod - def get_model_names(cls) -> Tuple[str, ...]: - return "text-bison-001", "chat-bison-001" diff --git a/spacy_llm/tests/test_combinations.py b/spacy_llm/tests/test_combinations.py index 0cd986c5..d36e72ac 100644 --- a/spacy_llm/tests/test_combinations.py +++ b/spacy_llm/tests/test_combinations.py @@ -12,7 +12,7 @@ @pytest.mark.skipif(has_langchain is False, reason="LangChain is not installed") @pytest.mark.parametrize( "model", - ["langchain.OpenAI.v1", "spacy.GPT-3-5.v1", "spacy.GPT-3-5.v2"], + ["langchain.OpenAI.v1", "spacy.GPT-3-5.v3", "spacy.GPT-4.v3"], ids=["langchain", "rest-openai", "rest-openai"], ) @pytest.mark.parametrize(