Skip to content

Commit

Permalink
Merge pull request #692 from Undertone0809/v1.16.7/opt-tool-schema
Browse files Browse the repository at this point in the history
pref: optimize when use define_tool define tool
  • Loading branch information
Undertone0809 authored May 26, 2024
2 parents 5a64460 + 2988ed9 commit 8a526ac
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 16 deletions.
31 changes: 17 additions & 14 deletions promptulate/tools/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,23 +303,26 @@ def from_define_tool(
Returns:
A ToolImpl instance.
"""
if not parameters:
schema: dict = function_to_tool_schema(callback)
elif isinstance(parameters, dict) and _validate_refined_schema(parameters):
schema: dict = parameters
elif isinstance(parameters, type) and issubclass(parameters, BaseModel):
schema: dict = _pydantic_to_refined_schema(parameters)
_name = name or callback.__name__
_description = description or callback.__doc__ or ""

if parameters:
if isinstance(parameters, dict):
schema = parameters
elif isinstance(parameters, type) and issubclass(parameters, BaseModel):
schema = _pydantic_to_refined_schema(parameters)
else:
raise TypeError(
f"{[cls.__name__]} parameters must be BaseModel or JSON schema."
) # noqa
else:
raise TypeError(
f"{[cls.__name__]} parameters must be BaseModel or JSON schema."
)

_description = description or ""
_doc = callback.__doc__ or ""
schema = function_to_tool_schema(callback)
schema["name"] = _name
schema["description"] = _description

return cls(
name=name or callback.__name__,
description=f"{_description}\n{_doc}",
name=schema["name"],
description=schema["description"],
callback=callback,
parameters=schema,
)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ name = "promptulate"
readme = "README.md"
homepage = "https://github.com/Undertone0809/promptulate"
repository = "https://github.com/Undertone0809/promptulate"
version = "1.16.6"
version = "1.16.7"
keywords = [
"promptulate",
"pne",
Expand Down
40 changes: 39 additions & 1 deletion tests/tools/test_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,9 +110,32 @@ def func_4(a: str, b: int):
"type": "object",
}

func1_schema_of_define_tool = {
"name": "mock tool",
"description": "mock tool description\nmock func 0",
"properties": {},
"type": "object",
}

func2_schema_of_define_tool = {
"name": "mock tool",
"description": "mock tool description\nmock func 1",
"properties": {
"a": {
"type": "string",
},
"b": {
"type": "integer",
},
},
"required": ["a", "b"],
"type": "object",
}


def test_define_tool():
"""Test initialize tool by define_tool function."""
# test func 0
tool = define_tool(
name="mock tool",
description="mock tool description",
Expand All @@ -125,7 +148,22 @@ def test_define_tool():
resp: str = tool.run()
assert resp == "result"

assert tool.to_schema() == func_0_schema
assert tool.to_schema() == func1_schema_of_define_tool

# test func 1
tool = define_tool(
name="mock tool",
description="mock tool description",
callback=func_1,
)

assert tool.name == "mock tool"
assert tool.description == "mock tool description\nmock func 1"

resp: str = tool.run(a="a", b=1)
assert resp == "result"

assert tool.to_schema() == func2_schema_of_define_tool


def test_tool_cls():
Expand Down

0 comments on commit 8a526ac

Please sign in to comment.