Skip to content

Commit

Permalink
Import from langchain_community instead of from langchain.
Browse files Browse the repository at this point in the history
  • Loading branch information
rmitsch committed Jan 29, 2024
1 parent 7313dfc commit 3516ee1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 16 deletions.
2 changes: 2 additions & 0 deletions spacy_llm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

try:
import langchain
import langchain_community

has_langchain = True
except (ImportError, AttributeError):
langchain = None
langchain_community = None
has_langchain = False

try:
Expand Down
35 changes: 19 additions & 16 deletions spacy_llm/models/langchain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from confection import SimpleFrozenDict

from ...compat import ExtraError, ValidationError, has_langchain, langchain
from ...compat import ExtraError, ValidationError, has_langchain, langchain_community
from ...registry import registry

try:
from langchain import llms # noqa: F401
from langchain_community import llms # noqa: F401
except (ImportError, AttributeError):
llms = None

Expand All @@ -18,16 +18,17 @@ def __init__(
api: str,
config: Dict[Any, Any],
query: Callable[
["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
Iterable[Iterable[Any]],
],
context_length: Optional[int],
):
"""Initializes model instance for integration APIs.
name (str): Name of LangChain model to instantiate.
api (str): Name of class/API.
config (Dict[Any, Any]): Config passed on to LangChain model.
query (Callable[[langchain.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
LLM prompts when supplied with the model instance.
query (Callable[[langchain_community.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
executing LLM prompts when supplied with the model instance.
context_length (Optional[int]): Context length for this model. Only necessary for sharding. If no context
length provided, prompts can't be sharded.
"""
Expand All @@ -39,7 +40,7 @@ def __init__(
@classmethod
def _init_langchain_model(
cls, name: str, api: str, config: Dict[Any, Any]
) -> "langchain.llms.BaseLLM":
) -> "langchain_community.llms.BaseLLM":
"""Initializes langchain model. langchain expects a range of different model ID argument names, depending on the
model class. There doesn't seem to be a clean way to determine those from the outset, we'll fail our way through
them.
Expand Down Expand Up @@ -73,12 +74,13 @@ def _init_langchain_model(
raise err

@staticmethod
def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]:
"""Returns langchain.llms.type_to_cls_dict.
RETURNS (Dict[str, Type[langchain.llms.BaseLLM]]): langchain.llms.type_to_cls_dict.
def get_type_to_cls_dict() -> Dict[str, Type["langchain_community.llms.BaseLLM"]]:
"""Returns langchain_community.llms.type_to_cls_dict.
RETURNS (Dict[str, Type[langchain_community.llms.BaseLLM]]): langchain_community.llms.type_to_cls_dict.
"""
return {
llm_id: getattr(langchain.llms, llm_id) for llm_id in langchain.llms.__all__
llm_id: getattr(langchain_community.llms, llm_id)
for llm_id in langchain_community.llms.__all__
}

def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
Expand All @@ -90,10 +92,10 @@ def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:

@staticmethod
def query_langchain(
model: "langchain.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
model: "langchain_community.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
) -> Iterable[Iterable[Any]]:
"""Query LangChain model naively.
model (langchain.llms.BaseLLM): LangChain model.
model (langchain_community.llms.BaseLLM): LangChain model.
prompts (Iterable[Iterable[Any]]): Prompts to execute.
RETURNS (Iterable[Iterable[Any]]): LLM responses.
"""
Expand All @@ -116,7 +118,7 @@ def langchain_model(
name: str,
query: Optional[
Callable[
["langchain.llms.BaseLLM", Iterable[Iterable[str]]],
["langchain_community.llms.BaseLLM", Iterable[Iterable[str]]],
Iterable[Iterable[str]],
]
] = None,
Expand Down Expand Up @@ -171,11 +173,12 @@ def register_models() -> None:
@registry.llm_queries("spacy.CallLangChain.v1")
def query_langchain() -> (
Callable[
["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
Iterable[Iterable[Any]],
]
):
"""Returns query Callable for LangChain.
RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
simple prompts on the specified LangChain model.
RETURNS (Callable[["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
executing simple prompts on the specified LangChain model.
"""
return LangChain.query_langchain

0 comments on commit 3516ee1

Please sign in to comment.