diff --git a/pyproject.toml b/pyproject.toml
index 1ba77808..d138c29a 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,8 @@ filterwarnings = [
     "ignore:^.*`__get_validators__` is deprecated.*",
     "ignore:^.*The `construct` method is deprecated.*",
     "ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
-    "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*"
+    "ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*",
+    "ignore:^.*was deprecated in langchain-community.*"
 ]
 markers = [
     "external: interacts with a (potentially cost-incurring) third-party API",
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 49239d6e..63862a4a 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -7,7 +7,7 @@ mypy>=0.990,<1.1.0; platform_machine != "aarch64" and python_version >= "3.7"
 black==22.3.0
 types-requests==2.28.11.16
 # Prompting libraries needed for testing
-langchain==0.0.331; python_version>="3.9"
+langchain>=0.1,<0.2; python_version>="3.9"
 # Workaround for LangChain bug: pin OpenAI version. To be removed after LangChain has been fixed - see
 # https://github.com/langchain-ai/langchain/issues/12967.
 openai>=0.27,<=0.28.1; python_version>="3.9"
diff --git a/setup.cfg b/setup.cfg
index dc72faea..d2ab0673 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -44,7 +44,7 @@ spacy_misc =
 
 [options.extras_require]
 langchain =
-    langchain==0.0.335
+    langchain>=0.1,<0.2
 transformers =
     torch>=1.13.1,<2.0
     transformers>=4.28.1,<5.0
diff --git a/spacy_llm/compat.py b/spacy_llm/compat.py
index 85f39036..65f32b07 100644
--- a/spacy_llm/compat.py
+++ b/spacy_llm/compat.py
@@ -19,10 +19,12 @@
 
 try:
     import langchain
+    import langchain_community
 
     has_langchain = True
 except (ImportError, AttributeError):
     langchain = None
+    langchain_community = None
     has_langchain = False
 
 try:
diff --git a/spacy_llm/models/langchain/model.py b/spacy_llm/models/langchain/model.py
index 45da9ae6..cded9fd7 100644
--- a/spacy_llm/models/langchain/model.py
+++ b/spacy_llm/models/langchain/model.py
@@ -2,11 +2,11 @@
 
 from confection import SimpleFrozenDict
 
-from ...compat import ExtraError, ValidationError, has_langchain, langchain
+from ...compat import ExtraError, ValidationError, has_langchain, langchain_community
 from ...registry import registry
 
 try:
-    from langchain import llms  # noqa: F401
+    from langchain_community import llms  # noqa: F401
 except (ImportError, AttributeError):
     llms = None
 
@@ -18,7 +18,8 @@ def __init__(
         api: str,
         config: Dict[Any, Any],
         query: Callable[
-            ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
+            ["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
+            Iterable[Iterable[Any]],
         ],
         context_length: Optional[int],
     ):
@@ -26,8 +27,8 @@ def __init__(
         name (str): Name of LangChain model to instantiate.
         api (str): Name of class/API.
         config (Dict[Any, Any]): Config passed on to LangChain model.
-        query (Callable[[langchain.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
-            LLM prompts when supplied with the model instance.
+        query (Callable[[langchain_community.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
+            executing LLM prompts when supplied with the model instance.
         context_length (Optional[int]): Context length for this model. Only necessary for sharding. If no context length
             provided, prompts can't be sharded.
         """
@@ -39,7 +40,7 @@ def __init__(
     @classmethod
     def _init_langchain_model(
         cls, name: str, api: str, config: Dict[Any, Any]
-    ) -> "langchain.llms.BaseLLM":
+    ) -> "langchain_community.llms.BaseLLM":
         """Initializes langchain model. langchain expects a range of different model ID argument names, depending on
         the model class. There doesn't seem to be a clean way to determine those from the outset, we'll fail our way
         through them.
@@ -73,12 +74,13 @@ def _init_langchain_model(
         raise err
 
     @staticmethod
-    def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]:
-        """Returns langchain.llms.type_to_cls_dict.
-        RETURNS (Dict[str, Type[langchain.llms.BaseLLM]]): langchain.llms.type_to_cls_dict.
+    def get_type_to_cls_dict() -> Dict[str, Type["langchain_community.llms.BaseLLM"]]:
+        """Returns langchain_community.llms.type_to_cls_dict.
+        RETURNS (Dict[str, Type[langchain_community.llms.BaseLLM]]): langchain_community.llms.type_to_cls_dict.
         """
         return {
-            llm_id: getattr(langchain.llms, llm_id) for llm_id in langchain.llms.__all__
+            llm_id: getattr(langchain_community.llms, llm_id)
+            for llm_id in langchain_community.llms.__all__
         }
 
     def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
@@ -90,15 +92,16 @@ def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
 
     @staticmethod
     def query_langchain(
-        model: "langchain.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
+        model: "langchain_community.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
     ) -> Iterable[Iterable[Any]]:
         """Query LangChain model naively.
-        model (langchain.llms.BaseLLM): LangChain model.
+        model (langchain_community.llms.BaseLLM): LangChain model.
         prompts (Iterable[Iterable[Any]]): Prompts to execute.
         RETURNS (Iterable[Iterable[Any]]): LLM responses.
         """
-        assert callable(model)
-        return [[model(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts]
+        return [
+            [model.invoke(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts
+        ]
 
     @staticmethod
     def _check_installation() -> None:
@@ -115,7 +118,7 @@ def langchain_model(
     name: str,
     query: Optional[
         Callable[
-            ["langchain.llms.BaseLLM", Iterable[Iterable[str]]],
+            ["langchain_community.llms.BaseLLM", Iterable[Iterable[str]]],
             Iterable[Iterable[str]],
         ]
     ] = None,
@@ -170,11 +173,12 @@ def register_models() -> None:
 @registry.llm_queries("spacy.CallLangChain.v1")
 def query_langchain() -> (
     Callable[
-        ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
+        ["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
+        Iterable[Iterable[Any]],
     ]
 ):
     """Returns query Callable for LangChain.
-    RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
-        simple prompts on the specified LangChain model.
+    RETURNS (Callable[["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
+        executing simple prompts on the specified LangChain model.
     """
     return LangChain.query_langchain