Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adjust LangChain usage in line with langchain >= 0.1 #433

Merged
merged 3 commits into from
Jan 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ filterwarnings = [
"ignore:^.*`__get_validators__` is deprecated.*",
"ignore:^.*The `construct` method is deprecated.*",
"ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
"ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*"
"ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*",
"ignore:^.*was deprecated in langchain-community.*"
]
markers = [
"external: interacts with a (potentially cost-incurring) third-party API",
Expand Down
2 changes: 1 addition & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ mypy>=0.990,<1.1.0; platform_machine != "aarch64" and python_version >= "3.7"
black==22.3.0
types-requests==2.28.11.16
# Prompting libraries needed for testing
langchain==0.0.331; python_version>="3.9"
langchain>=0.1,<0.2; python_version>="3.9"
# Workaround for LangChain bug: pin OpenAI version. To be removed after LangChain has been fixed - see
# https://github.com/langchain-ai/langchain/issues/12967.
openai>=0.27,<=0.28.1; python_version>="3.9"
Expand Down
2 changes: 1 addition & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ spacy_misc =

[options.extras_require]
langchain =
langchain==0.0.335
langchain>=0.1,<0.2
transformers =
torch>=1.13.1,<2.0
transformers>=4.28.1,<5.0
Expand Down
2 changes: 2 additions & 0 deletions spacy_llm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

try:
import langchain
import langchain_community

has_langchain = True
except (ImportError, AttributeError):
langchain = None
langchain_community = None
has_langchain = False

try:
Expand Down
40 changes: 22 additions & 18 deletions spacy_llm/models/langchain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

from confection import SimpleFrozenDict

from ...compat import ExtraError, ValidationError, has_langchain, langchain
from ...compat import ExtraError, ValidationError, has_langchain, langchain_community
from ...registry import registry

try:
from langchain import llms # noqa: F401
from langchain_community import llms # noqa: F401
except (ImportError, AttributeError):
llms = None

Expand All @@ -18,16 +18,17 @@ def __init__(
api: str,
config: Dict[Any, Any],
query: Callable[
["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
Iterable[Iterable[Any]],
],
context_length: Optional[int],
):
"""Initializes model instance for integration APIs.
name (str): Name of LangChain model to instantiate.
api (str): Name of class/API.
config (Dict[Any, Any]): Config passed on to LangChain model.
query (Callable[[langchain.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
LLM prompts when supplied with the model instance.
query (Callable[[langchain_community.llms.BaseLLM, Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
executing LLM prompts when supplied with the model instance.
context_length (Optional[int]): Context length for this model. Only necessary for sharding. If no context
length provided, prompts can't be sharded.
"""
Expand All @@ -39,7 +40,7 @@ def __init__(
@classmethod
def _init_langchain_model(
cls, name: str, api: str, config: Dict[Any, Any]
) -> "langchain.llms.BaseLLM":
) -> "langchain_community.llms.BaseLLM":
"""Initializes langchain model. langchain expects a range of different model ID argument names, depending on the
model class. There doesn't seem to be a clean way to determine those from the outset, we'll fail our way through
them.
Expand Down Expand Up @@ -73,12 +74,13 @@ def _init_langchain_model(
raise err

@staticmethod
def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]:
"""Returns langchain.llms.type_to_cls_dict.
RETURNS (Dict[str, Type[langchain.llms.BaseLLM]]): langchain.llms.type_to_cls_dict.
def get_type_to_cls_dict() -> Dict[str, Type["langchain_community.llms.BaseLLM"]]:
"""Returns langchain_community.llms.type_to_cls_dict.
RETURNS (Dict[str, Type[langchain_community.llms.BaseLLM]]): langchain_community.llms.type_to_cls_dict.
"""
return {
llm_id: getattr(langchain.llms, llm_id) for llm_id in langchain.llms.__all__
llm_id: getattr(langchain_community.llms, llm_id)
for llm_id in langchain_community.llms.__all__
}

def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:
Expand All @@ -90,15 +92,16 @@ def __call__(self, prompts: Iterable[Iterable[Any]]) -> Iterable[Iterable[Any]]:

@staticmethod
def query_langchain(
model: "langchain.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
model: "langchain_community.llms.BaseLLM", prompts: Iterable[Iterable[Any]]
) -> Iterable[Iterable[Any]]:
"""Query LangChain model naively.
model (langchain.llms.BaseLLM): LangChain model.
model (langchain_community.llms.BaseLLM): LangChain model.
prompts (Iterable[Iterable[Any]]): Prompts to execute.
RETURNS (Iterable[Iterable[Any]]): LLM responses.
"""
assert callable(model)
return [[model(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts]
return [
[model.invoke(pr) for pr in prompts_for_doc] for prompts_for_doc in prompts
]

@staticmethod
def _check_installation() -> None:
Expand All @@ -115,7 +118,7 @@ def langchain_model(
name: str,
query: Optional[
Callable[
["langchain.llms.BaseLLM", Iterable[Iterable[str]]],
["langchain_community.llms.BaseLLM", Iterable[Iterable[str]]],
Iterable[Iterable[str]],
]
] = None,
Expand Down Expand Up @@ -170,11 +173,12 @@ def register_models() -> None:
@registry.llm_queries("spacy.CallLangChain.v1")
def query_langchain() -> (
Callable[
["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]
["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]],
Iterable[Iterable[Any]],
]
):
"""Returns query Callable for LangChain.
RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable executing
simple prompts on the specified LangChain model.
RETURNS (Callable[["langchain_community.llms.BaseLLM", Iterable[Iterable[Any]]], Iterable[Iterable[Any]]]): Callable
executing simple prompts on the specified LangChain model.
"""
return LangChain.query_langchain
Loading