From 2988ed996761eb90d29b04231c7b10c664dfc8be Mon Sep 17 00:00:00 2001 From: zeeland Date: Sun, 26 May 2024 14:11:35 +0800 Subject: [PATCH] pref: optimize when use define_tool define tool --- promptulate/tools/base.py | 31 ++++++++++++++++-------------- pyproject.toml | 2 +- tests/tools/test_tool.py | 40 ++++++++++++++++++++++++++++++++++++++- 3 files changed, 57 insertions(+), 16 deletions(-) diff --git a/promptulate/tools/base.py b/promptulate/tools/base.py index cfa7243a..1c896224 100644 --- a/promptulate/tools/base.py +++ b/promptulate/tools/base.py @@ -303,23 +303,26 @@ def from_define_tool( Returns: A ToolImpl instance. """ - if not parameters: - schema: dict = function_to_tool_schema(callback) - elif isinstance(parameters, dict) and _validate_refined_schema(parameters): - schema: dict = parameters - elif isinstance(parameters, type) and issubclass(parameters, BaseModel): - schema: dict = _pydantic_to_refined_schema(parameters) + _name = name or callback.__name__ + _description = description or callback.__doc__ or "" + + if parameters: + if isinstance(parameters, dict): + schema = parameters + elif isinstance(parameters, type) and issubclass(parameters, BaseModel): + schema = _pydantic_to_refined_schema(parameters) + else: + raise TypeError( + f"{[cls.__name__]} parameters must be BaseModel or JSON schema." + ) # noqa else: - raise TypeError( - f"{[cls.__name__]} parameters must be BaseModel or JSON schema." - ) - - _description = description or "" - _doc = callback.__doc__ or "" + schema = function_to_tool_schema(callback) + schema["name"] = _name + schema["description"] = _description return cls( - name=name or callback.__name__, - description=f"{_description}\n{_doc}", + name=schema["name"], + description=schema["description"], callback=callback, parameters=schema, ) diff --git a/pyproject.toml b/pyproject.toml index 2be29b55..3359019b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ name = "promptulate" readme = "README.md" homepage = "https://github.com/Undertone0809/promptulate" repository = "https://github.com/Undertone0809/promptulate" -version = "1.16.6" +version = "1.16.7" keywords = [ "promptulate", "pne", diff --git a/tests/tools/test_tool.py b/tests/tools/test_tool.py index f9473692..0bd994bf 100644 --- a/tests/tools/test_tool.py +++ b/tests/tools/test_tool.py @@ -110,9 +110,32 @@ def func_4(a: str, b: int): "type": "object", } +func1_schema_of_define_tool = { + "name": "mock tool", + "description": "mock tool description\nmock func 0", + "properties": {}, + "type": "object", +} + +func2_schema_of_define_tool = { + "name": "mock tool", + "description": "mock tool description\nmock func 1", + "properties": { + "a": { + "type": "string", + }, + "b": { + "type": "integer", + }, + }, + "required": ["a", "b"], + "type": "object", +} + def test_define_tool(): """Test initialize tool by define_tool function.""" + # test func 0 tool = define_tool( name="mock tool", description="mock tool description", @@ -125,7 +148,22 @@ def test_define_tool(): resp: str = tool.run() assert resp == "result" - assert tool.to_schema() == func_0_schema + assert tool.to_schema() == func1_schema_of_define_tool + + # test func 1 + tool = define_tool( + name="mock tool", + description="mock tool description", + callback=func_1, + ) + + assert tool.name == "mock tool" + assert tool.description == "mock tool description\nmock func 1" + + resp: str = tool.run(a="a", b=1) + assert resp == "result" + + assert tool.to_schema() == func2_schema_of_define_tool def test_tool_cls():