Commit 5cd5c64

Merge branch 'refactor/model-registry-by-provider' of github.com:explosion/spacy-llm into refactor/model-registry-by-provider

rmitsch committed Apr 20, 2024
2 parents: d02bd41 + b1339de

Showing 11 changed files with 79 additions and 45 deletions.
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -26,7 +26,8 @@ filterwarnings = [
"ignore:^.*`__get_validators__` is deprecated.*",
"ignore:^.*The `construct` method is deprecated.*",
"ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
"ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*"
"ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*",
"ignore:^.*was deprecated in langchain-community.*"
]
markers = [
"external: interacts with a (potentially cost-incurring) third-party API",
2 changes: 1 addition & 1 deletion requirements-dev.txt
@@ -7,7 +7,7 @@ mypy>=0.990,<1.1.0; platform_machine != "aarch64" and python_version >= "3.7"
black==22.3.0
types-requests==2.28.11.16
# Prompting libraries needed for testing
-langchain==0.0.331; python_version>="3.9"
+langchain>=0.1,<0.2; python_version>="3.9"
# Workaround for LangChain bug: pin OpenAI version. To be removed after LangChain has been fixed - see
# https://github.com/langchain-ai/langchain/issues/12967.
openai>=0.27,<=0.28.1; python_version>="3.9"
4 changes: 2 additions & 2 deletions setup.cfg
@@ -1,5 +1,5 @@
[metadata]
-version = 0.7.0
+version = 0.7.1
description = Integrating LLMs into structured NLP pipelines
author = Explosion
author_email = [email protected]
@@ -44,7 +44,7 @@ spacy_misc =

[options.extras_require]
langchain =
-langchain==0.0.335
+langchain>=0.1,<0.2
transformers =
torch>=1.13.1,<2.0
transformers>=4.28.1,<5.0
2 changes: 2 additions & 0 deletions spacy_llm/compat.py
@@ -19,10 +19,12 @@

try:
import langchain
+import langchain_community

has_langchain = True
except (ImportError, AttributeError):
langchain = None
+langchain_community = None
has_langchain = False

try:
35 changes: 19 additions & 16 deletions spacy_llm/models/langchain/model.py
@@ -2,11 +2,11 @@

from confection import SimpleFrozenDict

-from ...compat import ExtraError, ValidationError, has_langchain, langchain
+from ...compat import ExtraError, ValidationError, has_langchain, langchain_community
from ...registry import registry

try:
-from langchain import llms  # noqa: F401
+from langchain_community import llms  # noqa: F401
except (ImportError, AttributeError):
llms = None

@@ -18,16 +18,17 @@ def __init__(
api: str,
config: Dict[Any, Any],
query: Callable[
["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
Iterable[Iterable[Any]],
],
context_length: Optional[int],
):
"""Initializes model instance for integration APIs.
name (str): Name of LangChain model to instantiate.
api (str): Name of class/API.
config (Dict[Any, Any]): Config passed on to LangChain model.
-query (Callable[[langchain.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
-    LLM prompts when supplied with the model instance.
+query (Callable[[langchain_community.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
+    executing LLM prompts when supplied with the model instance.
context_length (Optional[int]): Context length for this model. Only necessary for sharding. If no context
length provided, prompts can't be sharded.
"""
@@ -39,7 +40,7 @@ def __init__(
@classmethod
def _init_langchain_model(
cls, name: str, api: str, config: Dict[Any, Any]
-) -> "langchain.llms.BaseLLM":
+) -> "langchain_community.llms.BaseLLM":
"""Initializes langchain model. langchain expects a range of different model ID argument names, depending on the
model class. There doesn't seem to be a clean way to determine those from the outset, we'll fail our way through
them.
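
The docstring describes a fail-through strategy: because LangChain model classes disagree on the name of their model-ID argument, initialization simply tries each candidate name in turn and keeps the first one that works. A minimal sketch of that pattern, assuming a candidate list such as "model", "model_name", and "model_id" (the real set lives in the collapsed method body):

from typing import Optional

# Sketch of the fail-through initialization described in the docstring above.
# The candidate argument names are assumptions for illustration only.
def init_with_fallback(model_cls, name: str, config: dict):
    last_err: Optional[Exception] = None
    for id_arg in ("model", "model_name", "model_id"):
        try:
            # Try instantiating with this candidate model-ID argument name.
            return model_cls(**{id_arg: name}, **config)
        except Exception as err:  # e.g. a pydantic ValidationError for an unknown kwarg
            last_err = err
    # All candidates failed: surface the last error.
    raise last_err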
@@ -73,12 +74,13 @@ def _init_langchain_model(
raise err

@staticmethod
-def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]:
-    """Returns langchain.llms.type_to_cls_dict.
-    RETURNS (Dict[str, Type[langchain.llms.BaseLLM]]): langchain.llms.type_to_cls_dict.
+def get_type_to_cls_dict() -> Dict[str, Type["langchain_community.llms.BaseLLM"]]:
+    """Returns langchain_community.llms.type_to_cls_dict.
+    RETURNS (Dict[str, Type[langchain_community.llms.BaseLLM]]): langchain_community.llms.type_to_cls_dict.
"""
return {
-llm_id: getattr(langchain.llms, llm_id) for llm_id in langchain.llms.__all__
+llm_id: getattr(langchain_community.llms, llm_id)
+for llm_id in langchain_community.llms.__all__
}

def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
@@ -90,10 +92,10 @@ def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:

@staticmethod
def query_langchain(
model: "langchain.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
model: "langchain_community.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
) -> Iterable[Iterable[Any]]:
"""Query LangChain model naively.
-model (langchain.llms.BaseLLM): LangChain model.
+model (langchain_community.llms.BaseLLM): LangChain model.
prompts (Iterable[Iterable[Any]]): Prompts to execute.
RETURNS (Iterable[Iterable[Any]]): LLM responses.
"""
@@ -117,7 +119,7 @@ def langchain_model(
name: str,
query: Optional[
Callable[
["langchain.llms.BaseLLM", Iterable[Iterable[str]]],
["langchain_community.llms.BaseLLM", Iterable[Iterable[str]]],
Iterable[Iterable[str]],
]
] = None,
@@ -172,11 +174,12 @@ def register_models() -> None:
@registry.llm_queries("spacy.CallLangChain.v1")
def query_langchain() -> (
Callable[
["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
Iterable[Iterable[Any]],
]
):
"""Returns query Callable for LangChain.
RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
simple prompts on the specified LangChain model.
RETURNS (Callable[["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
executing simple prompts on the specified LangChain model.
"""
return LangChain.query_langchain
2 changes: 1 addition & 1 deletion spacy_llm/models/rest/azure/model.py
@@ -35,7 +35,7 @@ def __init__(
self._deployment_name = deployment_name
super().__init__(
name=name,
-endpoint=endpoint or endpoint,
+endpoint=endpoint,
config=config,
strict=strict,
max_tries=max_tries,
36 changes: 15 additions & 21 deletions spacy_llm/models/rest/openai/model.py
@@ -36,12 +36,6 @@ def credentials(self) -> Dict[str, str]:
if api_org:
headers["OpenAI-Organization"] = api_org

-# Ensure endpoint is supported.
-if self._endpoint not in (Endpoints.NON_CHAT, Endpoints.CHAT):
-    raise ValueError(
-        f"Endpoint {self._endpoint} isn't supported. Please use one of: {Endpoints.CHAT}, {Endpoints.NON_CHAT}."
-    )
-
return headers

def _verify_auth(self) -> None:
@@ -115,9 +109,21 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:

return responses

-if self._endpoint == Endpoints.CHAT:
-    # The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual
-    # requests.
+# The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual requests.
+
+if self._endpoint == Endpoints.NON_CHAT:
+    responses = _request({"prompt": prompts_for_doc})
+    if "error" in responses:
+        return responses["error"]
+    assert len(responses["choices"]) == len(prompts_for_doc)
+
+    for response in responses["choices"]:
+        if "text" in response:
+            api_responses.append(response["text"])
+        else:
+            api_responses.append(srsly.json_dumps(response))
+
+else:
for prompt in prompts_for_doc:
responses = _request(
{"messages": [{"role": "user", "content": prompt}]}
@@ -134,18 +140,6 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]:
)
)

-elif self._endpoint == Endpoints.NON_CHAT:
-    responses = _request({"prompt": prompts_for_doc})
-    if "error" in responses:
-        return responses["error"]
-    assert len(responses["choices"]) == len(prompts_for_doc)
-
-    for response in responses["choices"]:
-        if "text" in response:
-            api_responses.append(response["text"])
-        else:
-            api_responses.append(srsly.json_dumps(response))
-
all_api_responses.append(api_responses)

return all_api_responses
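
The restructured branch reflects the two request shapes: /completions accepts a whole batch of prompts in a single call, while /chat/completions takes one messages list per request, hence the per-prompt loop. A sketch of the two payloads; the prompt strings are placeholders, not taken from the diff:

# Illustrative payloads only; values are assumptions for the sketch.
non_chat_payload = {  # /completions: one request covers all prompts for a doc
    "prompt": ["First prompt.", "Second prompt."],
}
chat_payload = {  # /chat/completions: one request per prompt
    "messages": [{"role": "user", "content": "First prompt."}],
}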
7 changes: 5 additions & 2 deletions spacy_llm/tasks/entity_linker/task.py
@@ -105,7 +105,6 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]:
self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)]
self._n_shards = None

return [
EntityLinkerTask.highlight_ents_in_doc(doc, self._has_ent_cands_by_doc[i])
for i, doc in enumerate(docs)
@@ -335,7 +334,11 @@ def unhighlight_ents_in_doc(doc: Doc) -> Doc:
for ent in doc.ents
if ent.start - 1 > 0 and doc[ent.start - 1].text == "*"
}
-highlight_end_idx = {ent.end for ent in doc.ents if doc[ent.end].text == "*"}
+highlight_end_idx = {
+    ent.end
+    for ent in doc.ents
+    if ent.end < len(doc) and doc[ent.end].text == "*"
+}
highlight_idx = highlight_start_idx | highlight_end_idx

# Compute entity indices with removed highlights.
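The added ent.end < len(doc) guard covers entities that end at the last token of the doc, where doc[ent.end] would raise an IndexError. A small illustration with assumed example text:

# Assumed example showing the edge case the new bounds check guards against.
import spacy
from spacy.tokens import Span

nlp = spacy.blank("en")
doc = nlp.make_doc("Alice visits Boston")
doc.ents = [Span(doc, 2, 3, label="LOC")]  # entity ends at len(doc) == 3

ent = doc.ents[0]
# Without `ent.end < len(doc)`, doc[ent.end] would raise IndexError here.
if ent.end < len(doc) and doc[ent.end].text == "*":
    print("entity is followed by a closing highlight marker")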
1 change: 1 addition & 0 deletions spacy_llm/tests/tasks/legacy/test_ner.py
@@ -832,6 +832,7 @@ def test_ner_to_disk(noop_config, tmp_path: Path):
assert task1._label_dict == task2._label_dict == labels


+@pytest.mark.filterwarnings("ignore:Task supports sharding")
def test_label_inconsistency():
"""Test whether inconsistency between specified labels and labels in examples is detected."""
cfg = f"""
31 changes: 30 additions & 1 deletion spacy_llm/tests/tasks/test_entity_linker.py
@@ -685,9 +685,38 @@ def test_ent_highlighting():
EntityLinkerTask.highlight_ents_in_doc(doc).text
== "Alice goes to *Boston* to see the *Boston Celtics* game."
)


+@pytest.mark.parametrize(
+    "text,ents,include_ents",
+    [
+        (
+            "Alice goes to Boston to see the Boston Celtics game.",
+            [
+                {"start": 3, "end": 4, "label": "LOC"},
+                {"start": 7, "end": 9, "label": "ORG"},
+            ],
+            [True, True],
+        ),
+        (
+            "I went to see Boston in concert yesterday",
+            [
+                {"start": 4, "end": 5, "label": "GPE"},
+                {"start": 7, "end": 8, "label": "DATE"},
+            ],
+            [True, False],
+        ),
+    ],
+)
+def test_ent_unhighlighting(text, ents, include_ents):
+    """Tests unhighlighting of entities in text."""
+    nlp = spacy.blank("en")
+    doc = nlp.make_doc(text)
+    doc.ents = [Span(doc=doc, **ents[0]), Span(doc=doc, **ents[1])]
+
assert (
EntityLinkerTask.unhighlight_ents_in_doc(
-    EntityLinkerTask.highlight_ents_in_doc(doc)
+    EntityLinkerTask.highlight_ents_in_doc(doc, include_ents)
).text
== doc.text
== text
1 change: 1 addition & 0 deletions spacy_llm/tests/tasks/test_ner.py
@@ -824,6 +824,7 @@ def test_ner_to_disk(noop_config: str, tmp_path: Path):
assert task1._label_dict == task2._label_dict == labels


+@pytest.mark.filterwarnings("ignore:Task supports sharding")
def test_label_inconsistency():
"""Test whether inconsistency between specified labels and labels in examples is detected."""
cfg = f"""