Skip to content

Commit

Permalink
Add workaround for langchain model ID issue (#374)
Browse files Browse the repository at this point in the history
* Add workaround for langchain model ID issue.

* Refactor.

* Extend filterwarnings.

* Revert filterwarnings.

* Extend filterwarnings.

* Fix pydantic imports.

* Extend filterwarnings.
  • Loading branch information
rmitsch authored Nov 17, 2023
1 parent a0e1080 commit 6f2b241
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 27 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ filterwarnings = [
"ignore:^.*`__get_validators__` is deprecated.*",
"ignore:^.*The `construct` method is deprecated.*",
"ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
"ignore:^.*Pydantic V1 style `@validator` validators are deprecated."
"ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*"
]
markers = [
"external: interacts with a (potentially cost-incurring) third-party API",
Expand Down
5 changes: 3 additions & 2 deletions spacy_llm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
PYDANTIC_V2 = VERSION.startswith("2.")

if PYDANTIC_V2:
from pydantic.v1 import BaseModel, ValidationError, validator # noqa: F401
from pydantic.v1 import BaseModel, ExtraError, ValidationError # noqa: F401
from pydantic.v1 import validator
else:
from pydantic import BaseModel, ValidationError, validator # noqa: F401
from pydantic import BaseModel, ExtraError, ValidationError, validator # noqa: F401
69 changes: 46 additions & 23 deletions spacy_llm/models/langchain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from confection import SimpleFrozenDict

from ...compat import has_langchain, langchain
from ...compat import ExtraError, ValidationError, has_langchain, langchain
from ...registry import registry

try:
from langchain import base_language # noqa: F401
from langchain import llms # noqa: F401
except (ImportError, AttributeError):
llms = None
Expand All @@ -18,30 +17,59 @@ def __init__(
name: str,
api: str,
config: Dict[Any, Any],
query: Callable[
["langchain.base_language.BaseLanguageModel", Iterable[Any]],
Iterable[Any],
],
query: Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]],
):
"""Initializes model instance for integration APIs.
name (str): Name of LangChain model to instantiate.
api (str): Name of class/API.
config (Dict[Any, Any]): Config passed on to LangChain model.
query (Callable[[Any, Iterable[_PromptType]], Iterable[_ResponseType]]): Callable executing LLM prompts when
query (Callable[[langchain.llms.BaseLLM, Iterable[Any]], Iterable[Any]]): Callable executing LLM prompts when
supplied with the `integration` object.
"""
self._langchain_model = LangChain.get_type_to_cls_dict()[api](
model_name=name, **config
)
self._langchain_model = LangChain._init_langchain_model(name, api, config)
self.query = query
self._check_installation()

@classmethod
def _init_langchain_model(
    cls, name: str, api: str, config: Dict[Any, Any]
) -> "langchain.llms.BaseLLM":
    """Initializes langchain model. langchain expects a range of different model ID argument names, depending on the
    model class. There doesn't seem to be a clean way to determine those from the outset, so we fail our way through
    them: each known argument name is tried in turn, and the next one is only attempted if the model class rejected
    the current one as an extraneous field.
    Includes error checks for model ID arguments.
    name (str): Name of LangChain model to instantiate.
    api (str): Name of class/API.
    config (Dict[Any, Any]): Config passed on to LangChain model.
    RETURNS (langchain.llms.BaseLLM): Initialized LangChain model.
    """
    # Look up the model class once; an unknown `api` raises KeyError right away, as before.
    model_cls = cls.get_type_to_cls_dict()[api]
    last_err = None

    for model_init_arg in ("model", "model_name", "model_id"):
        try:
            return model_cls(**{model_init_arg: name}, **config)
        except ValidationError as err:
            # Did the model class complain that this particular model ID argument is extraneous?
            model_arg_rejected = any(
                isinstance(getattr(rerr, "exc", None), ExtraError)
                and model_init_arg in rerr.loc_tuple()
                for rerr in err.raw_errors
            )
            if not model_arg_rejected:
                # The model ID argument was accepted, so the failure is caused by something else (e.g. an invalid
                # config value). Retrying with other argument names can't fix that and would only bury the real
                # error under extraneous-field errors: raise error as-is.
                raise
            # `err` is unbound once the except block ends - keep a reference for exception chaining.
            last_err = err

    # All known model ID arguments were rejected as extraneous: raise error with hint on how to proceed.
    raise ValueError(
        "Couldn't initialize LangChain model with known model ID arguments. Please report this to "
        "https://github.com/explosion/spacy-llm/issues. Thank you!"
    ) from last_err

@staticmethod
def get_type_to_cls_dict() -> Dict[
str, Type["langchain.base_language.BaseLanguageModel"]
]:
def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]:
"""Returns langchain.llms.type_to_cls_dict.
RETURNS (Dict[str, Type[langchain.base_language.BaseLanguageModel]]): langchain.llms.type_to_cls_dict.
RETURNS (Dict[str, Type[langchain.llms.BaseLLM]]): langchain.llms.type_to_cls_dict.
"""
return getattr(langchain.llms, "type_to_cls_dict")

Expand All @@ -54,10 +82,10 @@ def __call__(self, prompts: Iterable[Any]) -> Iterable[Any]:

@staticmethod
def query_langchain(
model: "langchain.base_language.BaseLanguageModel", prompts: Iterable[Any]
model: "langchain.llms.BaseLLM", prompts: Iterable[Any]
) -> Iterable[Any]:
"""Query LangChain model naively.
model (langchain.base_language.BaseLanguageModel): LangChain model.
model (langchain.llms.BaseLLM): LangChain model.
prompts (Iterable[Any]): Prompts to execute.
RETURNS (Iterable[Any]): LLM responses.
"""
Expand All @@ -77,10 +105,7 @@ def _langchain_model_maker(class_id: str):
def langchain_model(
name: str,
query: Optional[
Callable[
["langchain.base_language.BaseLanguageModel", Iterable[str]],
Iterable[str],
]
Callable[["langchain.llms.BaseLLM", Iterable[str]], Iterable[str]]
] = None,
config: Dict[Any, Any] = SimpleFrozenDict(),
langchain_class_id: str = class_id,
Expand Down Expand Up @@ -123,9 +148,7 @@ def register_models() -> None:

@registry.llm_queries("spacy.CallLangChain.v1")
def query_langchain() -> (
Callable[
["langchain.base_language.BaseLanguageModel", Iterable[Any]], Iterable[Any]
]
Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]]
):
"""Returns query Callable for LangChain.
RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]]:): Callable executing simple prompts on
Expand Down
11 changes: 10 additions & 1 deletion spacy_llm/tests/models/test_langchain.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
import os
from typing import List

import pytest
import spacy

from spacy_llm.compat import has_langchain
from spacy_llm.models.langchain import LangChain
from spacy_llm.tests.compat import has_azure_openai_key

# Pipeline config: the "langchain.OpenAI.v1" model named "ada" (temperature 0.3), queried via
# "spacy.CallLangChain.v1" and paired with the "spacy.NoOp.v1" task.
# NOTE(review): presumably consumed by the tests below - its usage is not visible in this excerpt.
PIPE_CFG = {
"model": {
"@llm_models": "langchain.OpenAI.v1",
"name": "ada",
"config": {"temperature": 0.3},
"query": {"@llm_queries": "spacy.CallLangChain.v1"},
},
"task": {"@llm_tasks": "spacy.NoOp.v1"},
}


def langchain_model_reg_handles() -> List[str]:
    """Returns a list of all LangChain model reg handles.
    RETURNS (List[str]): One "langchain.<class name>.v1" handle per entry in LangChain.get_type_to_cls_dict().
    """
    # Only the classes are needed to build the handles - the dict keys (class IDs) are not used.
    return [
        f"langchain.{model_cls.__name__}.v1"
        for model_cls in LangChain.get_type_to_cls_dict().values()
    ]


@pytest.mark.external
@pytest.mark.skipif(has_langchain is False, reason="LangChain is not installed")
def test_initialization():
Expand Down

0 comments on commit 6f2b241

Please sign in to comment.