From 09fcad1a6ebdb737e3daece08df19f54c1dcd531 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 26 Jan 2024 15:43:33 +0100 Subject: [PATCH 1/5] Remove check for fixed endpoint for OpenAI models. (#429) --- spacy_llm/models/rest/azure/model.py | 2 +- spacy_llm/models/rest/openai/model.py | 36 +++++++++++---------------- 2 files changed, 16 insertions(+), 22 deletions(-) diff --git a/spacy_llm/models/rest/azure/model.py b/spacy_llm/models/rest/azure/model.py index 5a2d0fef..32adc0bb 100644 --- a/spacy_llm/models/rest/azure/model.py +++ b/spacy_llm/models/rest/azure/model.py @@ -35,7 +35,7 @@ def __init__( self._deployment_name = deployment_name super().__init__( name=name, - endpoint=endpoint or endpoint, + endpoint=endpoint, config=config, strict=strict, max_tries=max_tries, diff --git a/spacy_llm/models/rest/openai/model.py b/spacy_llm/models/rest/openai/model.py index b8bbdae3..7715f12c 100644 --- a/spacy_llm/models/rest/openai/model.py +++ b/spacy_llm/models/rest/openai/model.py @@ -36,12 +36,6 @@ def credentials(self) -> Dict[str, str]: if api_org: headers["OpenAI-Organization"] = api_org - # Ensure endpoint is supported. - if self._endpoint not in (Endpoints.NON_CHAT, Endpoints.CHAT): - raise ValueError( - f"Endpoint {self._endpoint} isn't supported. Please use one of: {Endpoints.CHAT}, {Endpoints.NON_CHAT}." - ) - return headers def _verify_auth(self) -> None: @@ -115,9 +109,21 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: return responses - if self._endpoint == Endpoints.CHAT: - # The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual - # requests. + # The OpenAI API doesn't support batching for /chat/completions yet, so we have to send individual requests. + + if self._endpoint == Endpoints.NON_CHAT: + responses = _request({"prompt": prompts_for_doc}) + if "error" in responses: + return responses["error"] + assert len(responses["choices"]) == len(prompts_for_doc) + + for response in responses["choices"]: + if "text" in response: + api_responses.append(response["text"]) + else: + api_responses.append(srsly.json_dumps(response)) + + else: for prompt in prompts_for_doc: responses = _request( {"messages": [{"role": "user", "content": prompt}]} @@ -134,18 +140,6 @@ def _request(json_data: Dict[str, Any]) -> Dict[str, Any]: ) ) - elif self._endpoint == Endpoints.NON_CHAT: - responses = _request({"prompt": prompts_for_doc}) - if "error" in responses: - return responses["error"] - assert len(responses["choices"]) == len(prompts_for_doc) - - for response in responses["choices"]: - if "text" in response: - api_responses.append(response["text"]) - else: - api_responses.append(srsly.json_dumps(response)) - all_api_responses.append(api_responses) return all_api_responses From 0fc46336ab63d02b77e347b766ec021d07347bfb Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 29 Jan 2024 10:02:07 +0100 Subject: [PATCH 2/5] Fix legacy NER test warning in CI (#430) * Fix legacy NER test warning in CI. * Add warning filter for NER inconsistency test. 
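Both tests are silenced with pytest's filterwarnings marker rather than a repo-wide filter, so the "Task supports sharding" warning is ignored only for the decorated tests. A minimal, self-contained sketch of the mechanism (the test body and exact warning text here are illustrative, not taken from this repo):

    import warnings

    import pytest

    @pytest.mark.filterwarnings("ignore:Task supports sharding")
    def test_sketch():
        # pytest parses the marker argument like a -W filter spec
        # ("action:message"): any warning whose message starts with
        # "Task supports sharding" is ignored for this test only,
        # even when CI promotes warnings to errors.
        warnings.warn("Task supports sharding, but sharding is disabled.")
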
--- spacy_llm/tests/tasks/legacy/test_ner.py | 1 + spacy_llm/tests/tasks/test_ner.py | 1 + 2 files changed, 2 insertions(+) diff --git a/spacy_llm/tests/tasks/legacy/test_ner.py b/spacy_llm/tests/tasks/legacy/test_ner.py index 3d9c133a..551e3dba 100644 --- a/spacy_llm/tests/tasks/legacy/test_ner.py +++ b/spacy_llm/tests/tasks/legacy/test_ner.py @@ -832,6 +832,7 @@ def test_ner_to_disk(noop_config, tmp_path: Path): assert task1._label_dict == task2._label_dict == labels +@pytest.mark.filterwarnings("ignore:Task supports sharding") def test_label_inconsistency(): """Test whether inconsistency between specified labels and labels in examples is detected.""" cfg = f""" diff --git a/spacy_llm/tests/tasks/test_ner.py b/spacy_llm/tests/tasks/test_ner.py index e8782d08..5fe4b178 100644 --- a/spacy_llm/tests/tasks/test_ner.py +++ b/spacy_llm/tests/tasks/test_ner.py @@ -820,6 +820,7 @@ def test_ner_to_disk(noop_config: str, tmp_path: Path): assert task1._label_dict == task2._label_dict == labels +@pytest.mark.filterwarnings("ignore:Task supports sharding") def test_label_inconsistency(): """Test whether inconsistency between specified labels and labels in examples is detected.""" cfg = f""" From 2e88594eaae04105e75fe6e280c732464963a65d Mon Sep 17 00:00:00 2001 From: Magdalena Aniol <96200718+magdaaniol@users.noreply.github.com> Date: Mon, 29 Jan 2024 10:38:15 +0100 Subject: [PATCH 3/5] Fix/index error when unhighlighting (#434) * make sure nonhighlighted ents don't cause IndexError when unhighlighting * linting --- spacy_llm/tasks/entity_linker/task.py | 7 +++-- spacy_llm/tests/tasks/test_entity_linker.py | 31 ++++++++++++++++++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/spacy_llm/tasks/entity_linker/task.py b/spacy_llm/tasks/entity_linker/task.py index 86426ed0..fd44506d 100644 --- a/spacy_llm/tasks/entity_linker/task.py +++ b/spacy_llm/tasks/entity_linker/task.py @@ -105,7 +105,6 @@ def _preprocess_docs_for_prompt(self, docs: Iterable[Doc]) -> Iterable[Doc]: self._ents_cands_by_shard = [[] * len(self._ents_cands_by_doc)] self._has_ent_cands_by_shard = [[] * len(self._ents_cands_by_doc)] self._n_shards = None - return [ EntityLinkerTask.highlight_ents_in_doc(doc, self._has_ent_cands_by_doc[i]) for i, doc in enumerate(docs) @@ -335,7 +334,11 @@ def unhighlight_ents_in_doc(doc: Doc) -> Doc: for ent in doc.ents if ent.start - 1 > 0 and doc[ent.start - 1].text == "*" } - highlight_end_idx = {ent.end for ent in doc.ents if doc[ent.end].text == "*"} + highlight_end_idx = { + ent.end + for ent in doc.ents + if ent.end < len(doc) and doc[ent.end].text == "*" + } highlight_idx = highlight_start_idx | highlight_end_idx # Compute entity indices with removed highlights. diff --git a/spacy_llm/tests/tasks/test_entity_linker.py b/spacy_llm/tests/tasks/test_entity_linker.py index 6101236b..45f18e7e 100644 --- a/spacy_llm/tests/tasks/test_entity_linker.py +++ b/spacy_llm/tests/tasks/test_entity_linker.py @@ -682,9 +682,38 @@ def test_ent_highlighting(): EntityLinkerTask.highlight_ents_in_doc(doc).text == "Alice goes to *Boston* to see the *Boston Celtics* game." 
) + + +@pytest.mark.parametrize( + "text,ents,include_ents", + [ + ( + "Alice goes to Boston to see the Boston Celtics game.", + [ + {"start": 3, "end": 4, "label": "LOC"}, + {"start": 7, "end": 9, "label": "ORG"}, + ], + [True, True], + ), + ( + "I went to see Boston in concert yesterday", + [ + {"start": 4, "end": 5, "label": "GPE"}, + {"start": 7, "end": 8, "label": "DATE"}, + ], + [True, False], + ), + ], +) +def test_ent_unhighlighting(text, ents, include_ents): + """Tests unhighlighting of entities in text.""" + nlp = spacy.blank("en") + doc = nlp.make_doc(text) + doc.ents = [Span(doc=doc, **ents[0]), Span(doc=doc, **ents[1])] + assert ( EntityLinkerTask.unhighlight_ents_in_doc( - EntityLinkerTask.highlight_ents_in_doc(doc) + EntityLinkerTask.highlight_ents_in_doc(doc, include_ents) ).text == doc.text == text From 789ee8bda6b8f6d6ed10d5dd6a271fcd9b4f6a8e Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 29 Jan 2024 10:56:07 +0100 Subject: [PATCH 4/5] Adjust LangChain usage in line with `langchain` >= 0.1 (#433) * Update langchain pin. Adjust langchain model invocation in line with >= 0.1 changes. * Import from langchain_community instead from langchain. * Ignore langchain_community deprecation warning. --- pyproject.toml | 3 ++- requirements-dev.txt | 2 +- setup.cfg | 2 +- spacy_llm/compat.py | 2 ++ spacy_llm/models/langchain/model.py | 40 ++++++++++++++++------------- 5 files changed, 28 insertions(+), 21 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1ba77808..d138c29a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,8 @@ filterwarnings = [ "ignore:^.*`__get_validators__` is deprecated.*", "ignore:^.*The `construct` method is deprecated.*", "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*", - "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*" + "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*", + "ignore:^.*was deprecated in langchain-community.*" ] markers = [ "external: interacts with a (potentially cost-incurring) third-party API", diff --git a/requirements-dev.txt b/requirements-dev.txt index 49239d6e..63862a4a 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -7,7 +7,7 @@ mypy>=0.990,<1.1.0; platform_machine != "aarch64" and python_version >= "3.7" black==22.3.0 types-requests==2.28.11.16 # Prompting libraries needed for testing -langchain==0.0.331; python_version>="3.9" +langchain>=0.1,<0.2; python_version>="3.9" # Workaround for LangChain bug: pin OpenAI version. To be removed after LangChain has been fixed - see # https://github.com/langchain-ai/langchain/issues/12967. 
openai>=0.27,<=0.28.1; python_version>="3.9" diff --git a/setup.cfg b/setup.cfg index dc72faea..d2ab0673 100644 --- a/setup.cfg +++ b/setup.cfg @@ -44,7 +44,7 @@ spacy_misc = [options.extras_require] langchain = - langchain==0.0.335 + langchain>=0.1,<0.2 transformers = torch>=1.13.1,<2.0 transformers>=4.28.1,<5.0 diff --git a/spacy_llm/compat.py b/spacy_llm/compat.py index 85f39036..65f32b07 100644 --- a/spacy_llm/compat.py +++ b/spacy_llm/compat.py @@ -19,10 +19,12 @@ try: import langchain + import langchain_community has_langchain = True except (ImportError, AttributeError): langchain = None + langchain_community = None has_langchain = False try: diff --git a/spacy_llm/models/langchain/model.py b/spacy_llm/models/langchain/model.py index 45da9ae6..cded9fd7 100644 --- a/spacy_llm/models/langchain/model.py +++ b/spacy_llm/models/langchain/model.py @@ -2,11 +2,11 @@ from confection import SimpleFrozenDict -from ...compat import ExtraError, ValidationError, has_langchain, langchain +from ...compat import ExtraError, ValidationError, has_langchain, langchain_community from ...registry import registry try: - from langchain import llms # noqa: F401 + from langchain_community import llms # noqa: F401 except (ImportError, AttributeError): llms = None @@ -18,7 +18,8 @@ def __init__( api: str, config: Dict[Any, Any], query: Callable[ - ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]] + ["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]], + Iterable[Iterable[Any]], ], context_length: Optional[int], ): @@ -26,8 +27,8 @@ def __init__( name (str): Name of LangChain model to instantiate. api (str): Name of class/API. config (Dict[Any, Any]): Config passed on to LangChain model. - query (Callable[[langchain.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing - LLM prompts when supplied with the model instance. + query (Callable[[langchain_community.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable + executing LLM prompts when supplied with the model instance. context_length (Optional[int]): Context length for this model. Only necessary for sharding. If no context length provided, prompts can't be sharded. """ @@ -39,7 +40,7 @@ def __init__( @classmethod def _init_langchain_model( cls, name: str, api: str, config: Dict[Any, Any] - ) -> "langchain.llms.BaseLLM": + ) -> "langchain_community.llms.BaseLLM": """Initializes langchain model. langchain expects a range of different model ID argument names, depending on the model class. There doesn't seem to be a clean way to determine those from the outset, we'll fail our way through them. @@ -73,12 +74,13 @@ def _init_langchain_model( raise err @staticmethod - def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]: - """Returns langchain.llms.type_to_cls_dict. - RETURNS (Dict[str, Type[langchain.llms.BaseLLM]]): langchain.llms.type_to_cls_dict. + def get_type_to_cls_dict() -> Dict[str, Type["langchain_community.llms.BaseLLM"]]: + """Returns langchain_community.llms.type_to_cls_dict. + RETURNS (Dict[str, Type[langchain_community.llms.BaseLLM]]): langchain_community.llms.type_to_cls_dict. 
""" return { - llm_id: getattr(langchain.llms, llm_id) for llm_id in langchain.llms.__all__ + llm_id: getattr(langchain_community.llms, llm_id) + for llm_id in langchain_community.llms.__all__ } def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]: @@ -90,15 +92,16 @@ def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]: @staticmethod def query_langchain( - model: "langchain.llms.BaseLLM", prompts: Iterable[Iterable[Any]] + model: "langchain_community.llms.BaseLLM", prompts: Iterable[Iterable[Any]] ) -> Iterable[Iterable[Any]]: """Query LangChain model naively. - model (langchain.llms.BaseLLM): LangChain model. + model (langchain_community.llms.BaseLLM): LangChain model. prompts (Iterable[Iterable[Any]]): Prompts to execute. RETURNS (Iterable[Iterable[Any]]): LLM responses. """ - assert callable(model) - return [[model(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts] + return [ + [model.invoke(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts + ] @staticmethod def _check_installation() -> None: @@ -115,7 +118,7 @@ def langchain_model( name: str, query: Optional[ Callable[ - ["langchain.llms.BaseLLM", Iterable[Iterable[str]]], + ["langchain_community.llms.BaseLLM", Iterable[Iterable[str]]], Iterable[Iterable[str]], ] ] = None, @@ -170,11 +173,12 @@ def register_models() -> None: @registry.llm_queries("spacy.CallLangChain.v1") def query_langchain() -> ( Callable[ - ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]] + ["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]], + Iterable[Iterable[Any]], ] ): """Returns query Callable for LangChain. - RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing - simple prompts on the specified LangChain model. + RETURNS (Callable[["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable + executing simple prompts on the specified LangChain model. """ return LangChain.query_langchain From c87d5a6373485c7510d92b5fed08770f82c96c1a Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Mon, 29 Jan 2024 14:36:19 +0100 Subject: [PATCH 5/5] Bump version. (#435) --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index d2ab0673..ebe4d44a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,5 +1,5 @@ [metadata] -version = 0.7.0 +version = 0.7.1 description = Integrating LLMs into structured NLP pipelines author = Explosion author_email = contact@explosion.ai