Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add workaround for langchain model ID issue #374

Merged
merged 8 commits into from
Nov 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ filterwarnings = [
"ignore:^.*`__get_validators__` is deprecated.*",
"ignore:^.*The `construct` method is deprecated.*",
"ignore:^.*Skipping device Apple Paravirtual device that does not support Metal 2.0.*",
"ignore:^.*Pydantic V1 style `@validator` validators are deprecated."
"ignore:^.*Pydantic V1 style `@validator` validators are deprecated.*"
]
markers = [
"external: interacts with a (potentially cost-incurring) third-party API",
Expand Down
5 changes: 3 additions & 2 deletions spacy_llm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
PYDANTIC_V2 = VERSION.startswith("2.")

if PYDANTIC_V2:
from pydantic.v1 import BaseModel, ValidationError, validator # noqa: F401
from pydantic.v1 import BaseModel, ExtraError, ValidationError # noqa: F401
from pydantic.v1 import validator
else:
from pydantic import BaseModel, ValidationError, validator # noqa: F401
from pydantic import BaseModel, ExtraError, ValidationError, validator # noqa: F401
69 changes: 46 additions & 23 deletions spacy_llm/models/langchain/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

from confection import SimpleFrozenDict

from ...compat import has_langchain, langchain
from ...compat import ExtraError, ValidationError, has_langchain, langchain
from ...registry import registry

try:
from langchain import base_language # noqa: F401
from langchain import llms # noqa: F401
except (ImportError, AttributeError):
llms = None
Expand All @@ -18,30 +17,59 @@ def __init__(
name: str,
api: str,
config: Dict[Any, Any],
query: Callable[
["langchain.base_language.BaseLanguageModel", Iterable[Any]],
Iterable[Any],
],
query: Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]],
):
"""Initializes model instance for integration APIs.
name (str): Name of LangChain model to instantiate.
api (str): Name of class/API.
config (Dict[Any, Any]): Config passed on to LangChain model.
query (Callable[[Any, Iterable[_PromptType]], Iterable[_ResponseType]]): Callable executing LLM prompts when
query (Callable[[langchain.llms.BaseLLM, Iterable[Any]], Iterable[Any]]): Callable executing LLM prompts when
supplied with the `integration` object.
"""
self._langchain_model = LangChain.get_type_to_cls_dict()[api](
model_name=name, **config
)
self._langchain_model = LangChain._init_langchain_model(name, api, config)
self.query = query
self._check_installation()

@classmethod
def _init_langchain_model(
cls, name: str, api: str, config: Dict[Any, Any]
) -> "langchain.llms.BaseLLM":
"""Initializes langchain model. langchain expects a range of different model ID argument names, depending on the
model class. There doesn't seem to be a clean way to determine those from the outset, we'll fail our way through
them.
Includes error checks for model ID arguments.
name (str): Name of LangChain model to instantiate.
api (str): Name of class/API.
config (Dict[Any, Any]): Config passed on to LangChain model.
"""
model_init_args = ["model", "model_name", "model_id"]
for model_init_arg in model_init_args:
try:
return cls.get_type_to_cls_dict()[api](
**{model_init_arg: name}, **config
)
except ValidationError as err:
if model_init_arg == model_init_args[-1]:
# If init error indicates that model ID arg is extraneous: raise error with hint on how to proceed.
if any(
[
rerr
for rerr in err.raw_errors
if isinstance(rerr.exc, ExtraError)
and model_init_arg in rerr.loc_tuple()
]
):
raise ValueError(
"Couldn't initialize LangChain model with known model ID arguments. Please report this to "
"https://github.com/explosion/spacy-llm/issues. Thank you!"
) from err
# Otherwise: raise error as-is.
raise err

@staticmethod
def get_type_to_cls_dict() -> Dict[
str, Type["langchain.base_language.BaseLanguageModel"]
]:
def get_type_to_cls_dict() -> Dict[str, Type["langchain.llms.BaseLLM"]]:
"""Returns langchain.llms.type_to_cls_dict.
RETURNS (Dict[str, Type[langchain.base_language.BaseLanguageModel]]): langchain.llms.type_to_cls_dict.
RETURNS (Dict[str, Type[langchain.llms.BaseLLM]]): langchain.llms.type_to_cls_dict.
"""
return getattr(langchain.llms, "type_to_cls_dict")

Expand All @@ -54,10 +82,10 @@ def __call__(self, prompts: Iterable[Any]) -> Iterable[Any]:

@staticmethod
def query_langchain(
model: "langchain.base_language.BaseLanguageModel", prompts: Iterable[Any]
model: "langchain.llms.BaseLLM", prompts: Iterable[Any]
) -> Iterable[Any]:
"""Query LangChain model naively.
model (langchain.base_language.BaseLanguageModel): LangChain model.
model (langchain.llms.BaseLLM): LangChain model.
prompts (Iterable[Any]): Prompts to execute.
RETURNS (Iterable[Any]): LLM responses.
"""
Expand All @@ -77,10 +105,7 @@ def _langchain_model_maker(class_id: str):
def langchain_model(
name: str,
query: Optional[
Callable[
["langchain.base_language.BaseLanguageModel", Iterable[str]],
Iterable[str],
]
Callable[["langchain.llms.BaseLLM", Iterable[str]], Iterable[str]]
] = None,
config: Dict[Any, Any] = SimpleFrozenDict(),
langchain_class_id: str = class_id,
Expand Down Expand Up @@ -123,9 +148,7 @@ def register_models() -> None:

@registry.llm_queries("spacy.CallLangChain.v1")
def query_langchain() -> (
Callable[
["langchain.base_language.BaseLanguageModel", Iterable[Any]], Iterable[Any]
]
Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]]
):
"""Returns query Callable for LangChain.
RETURNS (Callable[["langchain.llms.BaseLLM", Iterable[Any]], Iterable[Any]]:): Callable executing simple prompts on
Expand Down
11 changes: 10 additions & 1 deletion spacy_llm/tests/models/test_langchain.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,31 @@
import os
from typing import List

import pytest
import spacy

from spacy_llm.compat import has_langchain
from spacy_llm.models.langchain import LangChain
from spacy_llm.tests.compat import has_azure_openai_key

PIPE_CFG = {
"model": {
"@llm_models": "langchain.OpenAI.v1",
"name": "ada",
"config": {"temperature": 0.3},
"query": {"@llm_queries": "spacy.CallLangChain.v1"},
},
"task": {"@llm_tasks": "spacy.NoOp.v1"},
}


def langchain_model_reg_handles() -> List[str]:
"""Returns a list of all LangChain model reg handles."""
return [
f"langchain.{cls.__name__}.v1"
for class_id, cls in LangChain.get_type_to_cls_dict().items()
]


@pytest.mark.external
@pytest.mark.skipif(has_langchain is False, reason="LangChain is not installed")
def test_initialization():
Expand Down