Skip to content

Commit

Permalink
core[minor]: Relax constraints on type checking for tools and parsers (
Browse files Browse the repository at this point in the history
…langchain-ai#24459)

This will allow tools and parsers to accept pydantic models from any of
the
following namespaces:

* pydantic.BaseModel with pydantic 1
* pydantic.BaseModel with pydantic 2
* pydantic.v1.BaseModel with pydantic 2
  • Loading branch information
eyurtsev authored and olgamurraft committed Aug 16, 2024
1 parent 1933e22 commit d314895
Show file tree
Hide file tree
Showing 7 changed files with 211 additions and 30 deletions.
7 changes: 4 additions & 3 deletions libs/core/langchain_core/output_parsers/openai_tools.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import copy
import json
from json import JSONDecodeError
from typing import Any, Dict, List, Optional, Type
from typing import Any, Dict, List, Optional

from langchain_core.exceptions import OutputParserException
from langchain_core.messages import AIMessage, InvalidToolCall
Expand All @@ -13,8 +13,9 @@
)
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.outputs import ChatGeneration, Generation
from langchain_core.pydantic_v1 import BaseModel, ValidationError
from langchain_core.pydantic_v1 import ValidationError
from langchain_core.utils.json import parse_partial_json
from langchain_core.utils.pydantic import TypeBaseModel


def parse_tool_call(
Expand Down Expand Up @@ -255,7 +256,7 @@ def parse_result(self, result: List[Generation], *, partial: bool = False) -> An
class PydanticToolsParser(JsonOutputToolsParser):
"""Parse tools from OpenAI response."""

tools: List[Type[BaseModel]]
tools: List[TypeBaseModel]
"""The tools to parse."""

# TODO: Support more granular streaming of objects. Currently only streams once all
Expand Down
26 changes: 13 additions & 13 deletions libs/core/langchain_core/output_parsers/pydantic.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,16 @@
import json
from typing import Generic, List, Type, TypeVar, Union
from typing import Generic, List, Type

import pydantic # pydantic: ignore

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers import JsonOutputParser
from langchain_core.outputs import Generation
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION

if PYDANTIC_MAJOR_VERSION < 2:
PydanticBaseModel = pydantic.BaseModel

else:
from pydantic.v1 import BaseModel # pydantic: ignore

# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore

TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)
from langchain_core.utils.pydantic import (
PYDANTIC_MAJOR_VERSION,
PydanticBaseModel,
TBaseModel,
)


class PydanticOutputParser(JsonOutputParser, Generic[TBaseModel]):
Expand Down Expand Up @@ -122,3 +115,10 @@ def OutputType(self) -> Type[TBaseModel]:
```
{schema}
```""" # noqa: E501

# Re-exporting types for backwards compatibility
__all__ = [
"PydanticBaseModel",
"PydanticOutputParser",
"TBaseModel",
]
14 changes: 11 additions & 3 deletions libs/core/langchain_core/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
)
from langchain_core.runnables.utils import accepts_context
from langchain_core.utils.pydantic import (
TypeBaseModel,
_create_subset_model,
is_basemodel_subclass,
)
Expand Down Expand Up @@ -332,8 +333,15 @@ class ChildTool(BaseTool):
You can provide few-shot examples as a part of the description.
"""
args_schema: Optional[Type[BaseModel]] = None
"""Pydantic model class to validate and parse the tool's input arguments."""
args_schema: Optional[TypeBaseModel] = None
"""Pydantic model class to validate and parse the tool's input arguments.
Args schema should be either:
- A subclass of pydantic.BaseModel.
or
- A subclass of pydantic.v1.BaseModel if accessing v1 namespace in pydantic 2
"""
return_direct: bool = False
"""Whether to return the tool's output directly.
Expand Down Expand Up @@ -891,7 +899,7 @@ class StructuredTool(BaseTool):
"""Tool that can operate on any number of inputs."""

description: str = ""
args_schema: Type[BaseModel] = Field(..., description="The tool schema.")
args_schema: TypeBaseModel = Field(..., description="The tool schema.")
"""The input arguments' schema."""
func: Optional[Callable[..., Any]]
"""The function to run when the tool is called."""
Expand Down
37 changes: 29 additions & 8 deletions libs/core/langchain_core/utils/pydantic.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,14 @@
import inspect
import textwrap
from functools import wraps
from typing import Any, Callable, Dict, List, Optional, Type
from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union

from langchain_core.pydantic_v1 import BaseModel, root_validator
import pydantic # pydantic: ignore

from langchain_core.pydantic_v1 import (
BaseModel,
root_validator,
)


def get_pydantic_major_version() -> int:
Expand All @@ -23,6 +28,22 @@ def get_pydantic_major_version() -> int:
PYDANTIC_MAJOR_VERSION = get_pydantic_major_version()


if PYDANTIC_MAJOR_VERSION == 1:
PydanticBaseModel = pydantic.BaseModel
TypeBaseModel = Type[BaseModel]
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic.v1 import BaseModel # pydantic: ignore

# Union type needs to be last assignment to PydanticBaseModel to make mypy happy.
PydanticBaseModel = Union[BaseModel, pydantic.BaseModel] # type: ignore
TypeBaseModel = Union[Type[BaseModel], Type[pydantic.BaseModel]] # type: ignore
else:
raise ValueError(f"Unsupported Pydantic version: {PYDANTIC_MAJOR_VERSION}")


TBaseModel = TypeVar("TBaseModel", bound=PydanticBaseModel)


def is_basemodel_subclass(cls: Type) -> bool:
"""Check if the given class is a subclass of Pydantic BaseModel.
Expand All @@ -37,13 +58,13 @@ def is_basemodel_subclass(cls: Type) -> bool:
return False

if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore

if issubclass(cls, BaseModelV1Proper):
return True
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore

if issubclass(cls, BaseModelV2):
return True
Expand All @@ -65,13 +86,13 @@ def is_basemodel_instance(obj: Any) -> bool:
* pydantic.v1.BaseModel in Pydantic 2.x
"""
if PYDANTIC_MAJOR_VERSION == 1:
from pydantic import BaseModel as BaseModelV1Proper
from pydantic import BaseModel as BaseModelV1Proper # pydantic: ignore

if isinstance(obj, BaseModelV1Proper):
return True
elif PYDANTIC_MAJOR_VERSION == 2:
from pydantic import BaseModel as BaseModelV2
from pydantic.v1 import BaseModel as BaseModelV1
from pydantic import BaseModel as BaseModelV2 # pydantic: ignore
from pydantic.v1 import BaseModel as BaseModelV1 # pydantic: ignore

if isinstance(obj, BaseModelV2):
return True
Expand Down
116 changes: 115 additions & 1 deletion libs/core/tests/unit_tests/output_parsers/test_openai_tools.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,21 @@
from typing import Any, AsyncIterator, Iterator, List

from langchain_core.messages import AIMessageChunk, BaseMessage, ToolCallChunk
import pytest

from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ToolCallChunk,
)
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
JsonOutputToolsParser,
PydanticToolsParser,
)
from langchain_core.outputs import ChatGeneration
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION

STREAMED_MESSAGES: list = [
AIMessageChunk(content=""),
Expand Down Expand Up @@ -518,3 +527,108 @@ async def test_partial_pydantic_output_parser_async() -> None:

actual = [p async for p in chain.astream(None)]
assert actual == EXPECTED_STREAMED_PYDANTIC


@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_v1() -> None:
"""Test with pydantic.v1.BaseModel from pydantic 2."""
import pydantic # pydantic: ignore

class Forecast(pydantic.v1.BaseModel):
temperature: int
forecast: str

# Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
message = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "Forecast",
"args": {"temperature": 20, "forecast": "Sunny"},
}
],
)

generation = ChatGeneration(
message=message,
)

assert parser.parse_result([generation]) == [
Forecast(
temperature=20,
forecast="Sunny",
)
]


@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 2, reason="This test is for pydantic 2")
def test_parse_with_different_pydantic_2_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 2."""
import pydantic # pydantic: ignore

class Forecast(pydantic.BaseModel):
temperature: int
forecast: str

# Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
message = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "Forecast",
"args": {"temperature": 20, "forecast": "Sunny"},
}
],
)

generation = ChatGeneration(
message=message,
)

assert parser.parse_result([generation]) == [
Forecast(
temperature=20,
forecast="Sunny",
)
]


@pytest.mark.skipif(PYDANTIC_MAJOR_VERSION != 1, reason="This test is for pydantic 1")
def test_parse_with_different_pydantic_1_proper() -> None:
"""Test with pydantic.BaseModel from pydantic 1."""
import pydantic # pydantic: ignore

class Forecast(pydantic.BaseModel):
temperature: int
forecast: str

# Can't get pydantic to work here due to the odd typing of tryig to support
# both v1 and v2 in the same codebase.
parser = PydanticToolsParser(tools=[Forecast]) # type: ignore[list-item]
message = AIMessage(
content="",
tool_calls=[
{
"id": "call_OwL7f5PE",
"name": "Forecast",
"args": {"temperature": 20, "forecast": "Sunny"},
}
],
)

generation = ChatGeneration(
message=message,
)

assert parser.parse_result([generation]) == [
Forecast(
temperature=20,
forecast="Sunny",
)
]
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from langchain_core.exceptions import OutputParserException
from langchain_core.language_models import ParrotFakeChatModel
from langchain_core.output_parsers.json import JsonOutputParser
from langchain_core.output_parsers.pydantic import PydanticOutputParser, TBaseModel
from langchain_core.output_parsers.pydantic import PydanticOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION
from langchain_core.utils.pydantic import PYDANTIC_MAJOR_VERSION, TBaseModel

V1BaseModel = pydantic.BaseModel
if PYDANTIC_MAJOR_VERSION == 2:
Expand Down
37 changes: 37 additions & 0 deletions libs/core/tests/unit_tests/test_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -1526,3 +1526,40 @@ def _run(self, *args: Any, **kwargs: Any) -> str:
"title": "some_tool",
"type": "object",
}


@pytest.mark.parametrize("pydantic_model", TEST_MODELS)
def test_structured_tool_with_different_pydantic_versions(pydantic_model: Any) -> None:
"""This should test that one can type the args schema as a pydantic model."""
from langchain_core.tools import StructuredTool

def foo(a: int, b: str) -> str:
"""Hahaha"""
return "foo"

foo_tool = StructuredTool.from_function(
func=foo,
args_schema=pydantic_model,
)

assert foo_tool.invoke({"a": 5, "b": "hello"}) == "foo"

assert foo_tool.args_schema.schema() == {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
},
"required": ["a", "b"],
"title": pydantic_model.__name__,
"type": "object",
}

assert foo_tool.get_input_schema().schema() == {
"properties": {
"a": {"title": "A", "type": "integer"},
"b": {"title": "B", "type": "string"},
},
"required": ["a", "b"],
"title": pydantic_model.__name__,
"type": "object",
}

0 comments on commit d314895

Please sign in to comment.