From 9007eeb3cbea684636f08634ecc765db71a3ea96 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 17 Nov 2023 10:12:06 +0100 Subject: [PATCH 1/7] Add workaround for langchain model ID issue. --- spacy_llm/models/langchain/model.py | 47 +++++++++++++----------- spacy_llm/tests/models/test_langchain.py | 11 +++++- 2 files changed, 35 insertions(+), 23 deletions(-) diff --git a/spacy_llm/models/langchain/model.py b/spacy_llm/models/langchain/model.py index 03657cdf..5a9e35f7 100644 --- a/spacy_llm/models/langchain/model.py +++ b/spacy_llm/models/langchain/model.py @@ -1,12 +1,12 @@ from typing import Any, Callable, Dict, Iterable, Optional, Type from confection import SimpleFrozenDict +from pydantic import ValidationError from ...compat import has_langchain, langchain from ...registry import registry try: - from langchain import base_language # noqa: F401 from langchain import llms # noqa: F401 except (ImportError, AttributeError): llms = None @@ -18,30 +18,38 @@ def __init__( name: str, api: str, config: Dict[Any, Any], - query: Callable[ - ["langchain.base_language.BaseLanguageModel", Iterable[Any]], - Iterable[Any], - ], + query: Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]], ): """Initializes model instance for integration APIs. name (str): Name of LangChain model to instantiate. api (str): Name of class/API. config (Dict[Any, Any]): Config passed on to LangChain model. - query (Callable[[Any, Iterable[_PromptType]], Iterable[_ResponseType]]): Callable executing LLM prompts when + query (Callable[[langchain.llms.BaseLLM, Iterable[Any]], Iterable[Any]]): Callable executing LLM prompts when supplied with the `integration` object. """ - self._langchain_model = LangChain.get_type_to_cls_dict()[api]( - model_name=name, **config - ) + # LangChain expects a range of different model ID argument names. There doesn't seem to be a clean way to + # determine those from the outset, we'll fail our way through them. + model_init_args = ["model", "model_name", "model_id"] + for model_init_arg in model_init_args: + try: + self._langchain_model = LangChain.get_type_to_cls_dict()[api]( + **{model_init_arg: name}, **config + ) + break + except (ValidationError, UserWarning) as err: + if model_init_arg == model_init_args[-1]: + raise ValueError( + "Couldn't initialize LangChain model with known model ID arguments. Please report this to " + "https://github.com/explosion/spacy-llm/issues. Thank you!" + ) from err + self.query = query self._check_installation() @staticmethod - def get_type_to_cls_dict() -> Dict[ - str, Type["langchain.base_language.BaseLanguageModel"] - ]: + def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]: """Returns langchain.llms.type_to_cls_dict. - RETURNS (Dict[str, Type[langchain.base_language.BaseLanguageModel]]): langchain.llms.type_to_cls_dict. + RETURNS (Dict[str, Type[langchain.llms.BaseLLM]]): langchain.llms.type_to_cls_dict. """ return getattr(langchain.llms, "type_to_cls_dict") @@ -54,10 +62,10 @@ def __call__(self, prompts: Iterable[Any]) -> Iterable[Any]: @staticmethod def query_langchain( - model: "langchain.base_language.BaseLanguageModel", prompts: Iterable[Any] + model: "langchain.llms.BaseLLM", prompts: Iterable[Any] ) -> Iterable[Any]: """Query LangChain model naively. - model (langchain.base_language.BaseLanguageModel): LangChain model. + model (langchain.llms.BaseLLM): LangChain model. prompts (Iterable[Any]): Prompts to execute. RETURNS (Iterable[Any]): LLM responses. """ @@ -77,10 +85,7 @@ def _langchain_model_maker(class_id: str): def langchain_model( name: str, query: Optional[ - Callable[ - ["langchain.base_language.BaseLanguageModel", Iterable[str]], - Iterable[str], - ] + Callable[["langchain.llms.BaseLLM", Iterable[str]], Iterable[str]] ] = None, config: Dict[Any, Any] = SimpleFrozenDict(), langchain_class_id: str = class_id, @@ -123,9 +128,7 @@ def register_models() -> None: @registry.llm_queries("spacy.CallLangChain.v1") def query_langchain() -> ( - Callable[ - ["langchain.base_language.BaseLanguageModel", Iterable[Any]], Iterable[Any] - ] + Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]] ): """Returns query Callable for LangChain. RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]]:): Callable executing simple prompts on diff --git a/spacy_llm/tests/models/test_langchain.py b/spacy_llm/tests/models/test_langchain.py index fd48e0bb..57e984dc 100644 --- a/spacy_llm/tests/models/test_langchain.py +++ b/spacy_llm/tests/models/test_langchain.py @@ -1,9 +1,11 @@ import os +from typing import List import pytest import spacy from spacy_llm.compat import has_langchain +from spacy_llm.models.langchain import LangChain from spacy_llm.tests.compat import has_azure_openai_key PIPE_CFG = { @@ -11,12 +13,19 @@ "@llm_models": "langchain.OpenAI.v1", "name": "ada", "config": {"temperature": 0.3}, - "query": {"@llm_queries": "spacy.CallLangChain.v1"}, }, "task": {"@llm_tasks": "spacy.NoOp.v1"}, } +def langchain_model_reg_handles() -> List[str]: + """Returns a list of all LangChain model reg handles.""" + return [ + f"langchain.{cls.__name__}.v1" + for class_id, cls in LangChain.get_type_to_cls_dict().items() + ] + + @pytest.mark.external @pytest.mark.skipif(has_langchain is False, reason="LangChain is not installed") def test_initialization(): From db50162fd670c45fe56bba67c1f5d61d6f175afd Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 17 Nov 2023 10:36:57 +0100 Subject: [PATCH 2/7] Refactor. --- spacy_llm/models/langchain/model.py | 47 +++++++++++++++++++++-------- 1 file changed, 34 insertions(+), 13 deletions(-) diff --git a/spacy_llm/models/langchain/model.py b/spacy_llm/models/langchain/model.py index 5a9e35f7..11c40f7e 100644 --- a/spacy_llm/models/langchain/model.py +++ b/spacy_llm/models/langchain/model.py @@ -1,7 +1,7 @@ from typing import Any, Callable, Dict, Iterable, Optional, Type from confection import SimpleFrozenDict -from pydantic import ValidationError +from pydantic import ExtraError, ValidationError from ...compat import has_langchain, langchain from ...registry import registry @@ -27,24 +27,45 @@ def __init__( query (Callable[[langchain.llms.BaseLLM, Iterable[Any]], Iterable[Any]]): Callable executing LLM prompts when supplied with the `integration` object. """ - # LangChain expects a range of different model ID argument names. There doesn't seem to be a clean way to - # determine those from the outset, we'll fail our way through them. + self._langchain_model = LangChain._init_langchain_model(name, api, config) + self.query = query + self._check_installation() + + @classmethod + def _init_langchain_model( + cls, name: str, api: str, config: Dict[Any, Any] + ) -> "langchain.llms.BaseLLM": + """Initializes langchain model. langchain expects a range of different model ID argument names, depending on the + model class. There doesn't seem to be a clean way to determine those from the outset, we'll fail our way through + them. + Includes error checks for model ID arguments. + name (str): Name of LangChain model to instantiate. + api (str): Name of class/API. + config (Dict[Any, Any]): Config passed on to LangChain model. + """ model_init_args = ["model", "model_name", "model_id"] for model_init_arg in model_init_args: try: - self._langchain_model = LangChain.get_type_to_cls_dict()[api]( + return cls.get_type_to_cls_dict()[api]( **{model_init_arg: name}, **config ) - break - except (ValidationError, UserWarning) as err: + except ValidationError as err: if model_init_arg == model_init_args[-1]: - raise ValueError( - "Couldn't initialize LangChain model with known model ID arguments. Please report this to " - "https://github.com/explosion/spacy-llm/issues. Thank you!" - ) from err - - self.query = query - self._check_installation() + # If init error indicates that model ID arg is extraneous: raise error with hint on how to proceed. + if any( + [ + rerr + for rerr in err.raw_errors + if isinstance(rerr.exc, ExtraError) + and model_init_arg in rerr.loc_tuple() + ] + ): + raise ValueError( + "Couldn't initialize LangChain model with known model ID arguments. Please report this to " + "https://github.com/explosion/spacy-llm/issues. Thank you!" + ) from err + # Otherwise: raise error as-is. + raise err @staticmethod def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]: From dc41edf8db9621ca6803e505614e567ad4a58f63 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 17 Nov 2023 10:45:36 +0100 Subject: [PATCH 3/7] Extend filterwarnings. --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 54def2f1..36597a13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,8 @@ filterwarnings = [ "ignore:^.*The `dict` method is deprecated; use `model_dump` instead.*", "ignore:^.*The `parse_obj` method is deprecated; use `model_validate` instead.*", "ignore:^.*`__get_validators__` is deprecated.*", - "ignore:^.*The `construct` method is deprecated.*" + "ignore:^.*The `construct` method is deprecated.*", + "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*" ] markers = [ "external: interacts with a (potentially cost-incurring) third-party API", From 7fb4d68b6290e8a4532488b504817a7191816448 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 17 Nov 2023 10:47:58 +0100 Subject: [PATCH 4/7] Revert filterwarnings. --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 36597a13..9a66987b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,6 @@ filterwarnings = [ "ignore:^.*The `parse_obj` method is deprecated; use `model_validate` instead.*", "ignore:^.*`__get_validators__` is deprecated.*", "ignore:^.*The `construct` method is deprecated.*", - "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*" ] markers = [ "external: interacts with a (potentially cost-incurring) third-party API", From a14eca8314a2f2c1b13981f5114d29cf6e98931e Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 17 Nov 2023 10:50:15 +0100 Subject: [PATCH 5/7] Extend filterwarnings. --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 9a66987b..1ba77808 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,6 +25,8 @@ filterwarnings = [ "ignore:^.*The `parse_obj` method is deprecated; use `model_validate` instead.*", "ignore:^.*`__get_validators__` is deprecated.*", "ignore:^.*The `construct` method is deprecated.*", + "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*", + "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*" ] markers = [ "external: interacts with a (potentially cost-incurring) third-party API", From 3f2acb857e2ae13b03cb408038b9fb65bccdeeb5 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 17 Nov 2023 10:53:19 +0100 Subject: [PATCH 6/7] Fix pydantic imports. --- pyproject.toml | 1 - spacy_llm/compat.py | 5 +++-- spacy_llm/models/langchain/model.py | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1ba77808..9bba69f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,7 +26,6 @@ filterwarnings = [ "ignore:^.*`__get_validators__` is deprecated.*", "ignore:^.*The `construct` method is deprecated.*", "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*", - "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*" ] markers = [ "external: interacts with a (potentially cost-incurring) third-party API", diff --git a/spacy_llm/compat.py b/spacy_llm/compat.py index 76148b77..1554622f 100644 --- a/spacy_llm/compat.py +++ b/spacy_llm/compat.py @@ -55,6 +55,7 @@ PYDANTIC_V2 = VERSION.startswith("2.") if PYDANTIC_V2: - from pydantic.v1 import BaseModel, ValidationError, validator # noqa: F401 + from pydantic.v1 import BaseModel, ExtraError, ValidationError # noqa: F401 + from pydantic.v1 import validator else: - from pydantic import BaseModel, ValidationError, validator # noqa: F401 + from pydantic import BaseModel, ExtraError, ValidationError, validator # noqa: F401 diff --git a/spacy_llm/models/langchain/model.py b/spacy_llm/models/langchain/model.py index 11c40f7e..2e4be55f 100644 --- a/spacy_llm/models/langchain/model.py +++ b/spacy_llm/models/langchain/model.py @@ -1,9 +1,8 @@ from typing import Any, Callable, Dict, Iterable, Optional, Type from confection import SimpleFrozenDict -from pydantic import ExtraError, ValidationError -from ...compat import has_langchain, langchain +from ...compat import ExtraError, ValidationError, has_langchain, langchain from ...registry import registry try: From ec373c18f0fd2e3f1cc42f9b46a686afba3c15d7 Mon Sep 17 00:00:00 2001 From: Raphael Mitsch Date: Fri, 17 Nov 2023 10:57:52 +0100 Subject: [PATCH 7/7] Extend filterwarnings. --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 9bba69f7..1ba77808 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ filterwarnings = [ "ignore:^.*`__get_validators__` is deprecated.*", "ignore:^.*The `construct` method is deprecated.*", "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*", + "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*" ] markers = [ "external: interacts with a (potentially cost-incurring) third-party API",