diff --git a/spacy_llm/compat.py b/spacy_llm/compat.py
index 85f39036..65f32b07 100644
--- a/spacy_llm/compat.py
+++ b/spacy_llm/compat.py
@@ -19,10 +19,12 @@
 
 try:
     import langchain
+    import langchain_community
 
     has_langchain = True
 except (ImportError, AttributeError):
     langchain = None
+    langchain_community = None
     has_langchain = False
 
 try:
diff --git a/spacy_llm/models/langchain/model.py b/spacy_llm/models/langchain/model.py
index 6ede8b38..cded9fd7 100644
--- a/spacy_llm/models/langchain/model.py
+++ b/spacy_llm/models/langchain/model.py
@@ -2,11 +2,11 @@
 
 from confection import SimpleFrozenDict
 
-from ...compat import ExtraError, ValidationError, has_langchain, langchain
+from ...compat import ExtraError, ValidationError, has_langchain, langchain_community
 from ...registry import registry
 
 try:
-    from langchain import llms  # noqa: F401
+    from langchain_community import llms  # noqa: F401
 except (ImportError, AttributeError):
     llms = None
 
@@ -18,7 +18,8 @@ def __init__(
         api: str,
         config: Dict[Any, Any],
         query: Callable[
-            ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
+            ["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
+            Iterable[Iterable[Any]],
         ],
         context_length: Optional[int],
     ):
@@ -26,8 +27,8 @@ def __init__(
         name (str): Name of LangChain model to instantiate.
         api (str): Name of class/API.
         config (Dict[Any, Any]): Config passed on to LangChain model.
-        query (Callable[[langchain.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
-            LLM prompts when supplied with the model instance.
+        query (Callable[[langchain_community.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
+            executing LLM prompts when supplied with the model instance.
         context_length (Optional[int]): Context length for this model. Only necessary for sharding. If no context length
             provided, prompts can't be sharded.
         """
@@ -39,7 +40,7 @@ def __init__(
     @classmethod
     def _init_langchain_model(
         cls, name: str, api: str, config: Dict[Any, Any]
-    ) -> "langchain.llms.BaseLLM":
+    ) -> "langchain_community.llms.BaseLLM":
         """Initializes langchain model.
         langchain expects a range of different model ID argument names, depending on the model class. There doesn't seem
         to be a clean way to determine those from the outset, we'll fail our way through them.
@@ -73,12 +74,13 @@ def _init_langchain_model(
                 raise err
 
     @staticmethod
-    def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]:
-        """Returns langchain.llms.type_to_cls_dict.
-        RETURNS (Dict[str, Type[langchain.llms.BaseLLM]]): langchain.llms.type_to_cls_dict.
+    def get_type_to_cls_dict() -> Dict[str, Type["langchain_community.llms.BaseLLM"]]:
+        """Returns langchain_community.llms.type_to_cls_dict.
+        RETURNS (Dict[str, Type[langchain_community.llms.BaseLLM]]): langchain_community.llms.type_to_cls_dict.
         """
         return {
-            llm_id: getattr(langchain.llms, llm_id) for llm_id in langchain.llms.__all__
+            llm_id: getattr(langchain_community.llms, llm_id)
+            for llm_id in langchain_community.llms.__all__
         }
 
     def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
@@ -90,10 +92,10 @@ def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
 
     @staticmethod
     def query_langchain(
-        model: "langchain.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
+        model: "langchain_community.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
     ) -> Iterable[Iterable[Any]]:
         """Query LangChain model naively.
-        model (langchain.llms.BaseLLM): LangChain model.
+        model (langchain_community.llms.BaseLLM): LangChain model.
         prompts (Iterable[Iterable[Any]]): Prompts to execute.
         RETURNS (Iterable[Iterable[Any]]): LLM responses.
         """
@@ -116,7 +118,7 @@ def langchain_model(
         name: str,
         query: Optional[
             Callable[
-                ["langchain.llms.BaseLLM", Iterable[Iterable[str]]],
+                ["langchain_community.llms.BaseLLM", Iterable[Iterable[str]]],
                 Iterable[Iterable[str]],
             ]
         ] = None,
@@ -171,11 +173,12 @@ def register_models() -> None:
 @registry.llm_queries("spacy.CallLangChain.v1")
 def query_langchain() -> (
     Callable[
-        ["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
+        ["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
+        Iterable[Iterable[Any]],
     ]
 ):
     """Returns query Callable for LangChain.
-    RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
-        simple prompts on the specified LangChain model.
+    RETURNS (Callable[["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
+        executing simple prompts on the specified LangChain model.
     """
     return LangChain.query_langchain
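
For reviewers who want to smoke-test the migration locally, here is a minimal usage sketch. It is not part of the patch: only the query factory name "spacy.CallLangChain.v1" comes from the diff above; the registered model name "langchain.OpenAI.v1", the task "spacy.NER.v3", the model name "gpt-3.5-turbo-instruct", and the labels are illustrative assumptions. It assumes spaCy, spacy-llm, langchain, and langchain-community are installed and an OpenAI API key is configured.

# Hypothetical smoke test for the langchain_community migration; names flagged
# in the comments below are assumptions, not taken from this diff.
import spacy

nlp = spacy.blank("en")
nlp.add_pipe(
    "llm",
    config={
        # "spacy.NER.v3" and the label set are assumptions chosen for illustration.
        "task": {"@llm_tasks": "spacy.NER.v3", "labels": ["PERSON", "ORG"]},
        "model": {
            # "langchain.OpenAI.v1" and "gpt-3.5-turbo-instruct" are assumptions;
            # register_models() registers one entry per class listed in
            # langchain_community.llms.__all__ (see get_type_to_cls_dict above).
            "@llm_models": "langchain.OpenAI.v1",
            "name": "gpt-3.5-turbo-instruct",
            "config": {"temperature": 0.0},
            # The query factory registered in this file.
            "query": {"@llm_queries": "spacy.CallLangChain.v1"},
        },
    },
)

doc = nlp("Jack Dorsey co-founded Twitter.")
print([(ent.text, ent.label_) for ent in doc.ents])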