From b98eabac216273a634f32dc54810c684cd0c3e7b Mon Sep 17 00:00:00 2001 From: matt Date: Wed, 16 Oct 2024 18:39:03 -0500 Subject: [PATCH 1/5] Kwarg injection with `@activity` decorated functions --- docs/griptape-tools/index.md | 11 +++++++-- docs/griptape-tools/src/index_1.py | 36 ++++++++++++++++++++++++++++++ griptape/tools/aws_iam/tool.py | 4 ++-- griptape/tools/aws_s3/tool.py | 2 +- griptape/tools/date_time/tool.py | 2 +- griptape/tools/web_search/tool.py | 8 +++---- griptape/utils/decorators.py | 27 +++++++++++++++++++--- tests/mocks/mock_tool/tool.py | 24 ++++++++++---------- 8 files changed, 88 insertions(+), 26 deletions(-) diff --git a/docs/griptape-tools/index.md b/docs/griptape-tools/index.md index f483d493f..47bf71f9e 100644 --- a/docs/griptape-tools/index.md +++ b/docs/griptape-tools/index.md @@ -2,12 +2,19 @@ Tools give the LLM abilities to invoke outside APIs, reference data sets, and ge Griptape tools are special Python classes that LLMs can use to accomplish specific goals. Here is an example custom tool for generating a random number: +A tool can have many "activities" as denoted by the `@activity` decorator. Each activity has a description (used to provide context to the LLM), and the input schema that the LLM must follow in order to use the tool. + +When a function is decorated with `@activity`, the decorator injects keyword arguments into the function according to the schema. There are also two Griptape-provided keyword arguments: `params: dict` and `values: dict`. + +!!! info + If your schema defines any parameters named `params` or `values`, they will be overwritten by the Griptape-provided arguments. + +In the following example, all `@activity` decorated functions will result in the same value, but the method signature is defined in different ways. + ```python --8<-- "docs/griptape-tools/src/index_1.py" ``` -A tool can have many "activities" as denoted by the `@activity` decorator. Each activity has a description (used to provide context to the LLM), and the input schema that the LLM must follow in order to use the tool. - Output artifacts from all tool activities (except for `InfoArtifact` and `ErrorArtifact`) go to short-term `TaskMemory`. To disable that behavior set the `off_prompt` tool parameter to `False`: We provide a set of official Griptape Tools for accessing and processing data. You can also [build your own tools](./custom-tools/index.md). diff --git a/docs/griptape-tools/src/index_1.py b/docs/griptape-tools/src/index_1.py index 7929574d4..618592b24 100644 --- a/docs/griptape-tools/src/index_1.py +++ b/docs/griptape-tools/src/index_1.py @@ -1,4 +1,7 @@ +from __future__ import annotations + import random +import typing from schema import Literal, Optional, Schema @@ -19,5 +22,38 @@ class RandomNumberGenerator(BaseTool): def generate(self, params: dict) -> TextArtifact: return TextArtifact(str(round(random.random(), params["values"].get("decimals")))) + @activity( + config={ + "description": "Can be used to generate random numbers", + "schema": Schema( + {Optional(Literal("decimals", description="Number of decimals to round the random number to")): int} + ), + } + ) + def generate_with_decimals(self, decimals: typing.Optional[int]) -> TextArtifact: + return TextArtifact(str(round(random.random(), decimals))) + + @activity( + config={ + "description": "Can be used to generate random numbers", + "schema": Schema( + {Optional(Literal("decimals", description="Number of decimals to round the random number to")): int} + ), + } + ) + def generate_with_values(self, values: dict) -> TextArtifact: + return TextArtifact(str(round(random.random(), values.get("decimals")))) + + @activity( + config={ + "description": "Can be used to generate random numbers", + "schema": Schema( + {Optional(Literal("decimals", description="Number of decimals to round the random number to")): int} + ), + } + ) + def generate_with_kwargs(self, **kwargs) -> TextArtifact: + return TextArtifact(str(round(random.random(), kwargs.get("decimals")))) + RandomNumberGenerator() diff --git a/griptape/tools/aws_iam/tool.py b/griptape/tools/aws_iam/tool.py index 6c1bed054..d5b20d56a 100644 --- a/griptape/tools/aws_iam/tool.py +++ b/griptape/tools/aws_iam/tool.py @@ -46,7 +46,7 @@ def get_user_policy(self, params: dict) -> TextArtifact | ErrorArtifact: return ErrorArtifact(f"error returning policy document: {e}") @activity(config={"description": "Can be used to list AWS MFA Devices"}) - def list_mfa_devices(self, _: dict) -> ListArtifact | ErrorArtifact: + def list_mfa_devices(self) -> ListArtifact | ErrorArtifact: try: devices = self.client.list_mfa_devices() return ListArtifact([TextArtifact(str(d)) for d in devices["MFADevices"]]) @@ -76,7 +76,7 @@ def list_user_policies(self, params: dict) -> ListArtifact | ErrorArtifact: return ErrorArtifact(f"error listing iam user policies: {e}") @activity(config={"description": "Can be used to list AWS IAM users."}) - def list_users(self, _: dict) -> ListArtifact | ErrorArtifact: + def list_users(self) -> ListArtifact | ErrorArtifact: try: users = self.client.list_users() return ListArtifact([TextArtifact(str(u)) for u in users["Users"]]) diff --git a/griptape/tools/aws_s3/tool.py b/griptape/tools/aws_s3/tool.py index b352da2d5..fb0c2fdf2 100644 --- a/griptape/tools/aws_s3/tool.py +++ b/griptape/tools/aws_s3/tool.py @@ -79,7 +79,7 @@ def get_object_acl(self, params: dict) -> TextArtifact | ErrorArtifact: return ErrorArtifact(f"error getting object acl: {e}") @activity(config={"description": "Can be used to list all AWS S3 buckets."}) - def list_s3_buckets(self, _: dict) -> ListArtifact | ErrorArtifact: + def list_s3_buckets(self) -> ListArtifact | ErrorArtifact: try: buckets = self.client.list_buckets() diff --git a/griptape/tools/date_time/tool.py b/griptape/tools/date_time/tool.py index 5181dbe3e..dfdfcc578 100644 --- a/griptape/tools/date_time/tool.py +++ b/griptape/tools/date_time/tool.py @@ -9,7 +9,7 @@ class DateTimeTool(BaseTool): @activity(config={"description": "Can be used to return current date and time."}) - def get_current_datetime(self, _: dict) -> BaseArtifact: + def get_current_datetime(self) -> BaseArtifact: try: current_datetime = datetime.now() diff --git a/griptape/tools/web_search/tool.py b/griptape/tools/web_search/tool.py index 557c26a52..cbe4dcbf6 100644 --- a/griptape/tools/web_search/tool.py +++ b/griptape/tools/web_search/tool.py @@ -30,12 +30,10 @@ class WebSearchTool(BaseTool): ), }, ) - def search(self, props: dict) -> ListArtifact | ErrorArtifact: - values = props["values"] - query = values["query"] - extra_keys = {k: values[k] for k in values.keys() - {"query"}} + def search(self, values: dict) -> ListArtifact | ErrorArtifact: + query = values.pop("query") try: - return self.web_search_driver.search(query, **extra_keys) + return self.web_search_driver.search(query, **values) except Exception as e: return ErrorArtifact(f"Error searching '{query}' with {self.web_search_driver.__class__.__name__}: {e}") diff --git a/griptape/utils/decorators.py b/griptape/utils/decorators.py index 356f4eec0..f28972795 100644 --- a/griptape/utils/decorators.py +++ b/griptape/utils/decorators.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +import inspect from typing import Any, Callable, Optional import schema @@ -14,7 +15,7 @@ ) -def activity(config: dict) -> Any: +def activity(config: dict) -> Any: # noqa: C901 validated_config = CONFIG_SCHEMA.validate(config) validated_config.update({k: v for k, v in config.items() if k not in validated_config}) @@ -24,8 +25,28 @@ def activity(config: dict) -> Any: def decorator(func: Callable) -> Any: @functools.wraps(func) - def wrapper(self: Any, *args, **kwargs) -> Any: - return func(self, *args, **kwargs) + def wrapper(self: Any, params: dict) -> Any: + func_params = inspect.signature(func).parameters.copy() + func_params.pop("self") + + kwarg_var = None + for param in func_params.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + kwarg_var = func_params.pop(param.name).name + break + + kwargs = {k: v for k, v in params.get("values", {}).items() if k in func_params or kwarg_var is not None} + + if "params" in func_params or kwarg_var is not None: + kwargs["params"] = params + if "values" in func_params or kwarg_var is not None: + kwargs["values"] = params.get("values") + + for param_name in func_params.keys(): # noqa: SIM118 + if param_name not in kwargs: + kwargs[param_name] = None + + return func(self, **kwargs) setattr(wrapper, "name", func.__name__) setattr(wrapper, "config", validated_config) diff --git a/tests/mocks/mock_tool/tool.py b/tests/mocks/mock_tool/tool.py index a9ee86c3c..4675c761e 100644 --- a/tests/mocks/mock_tool/tool.py +++ b/tests/mocks/mock_tool/tool.py @@ -24,8 +24,8 @@ class MockTool(BaseTool): "schema": Schema({Literal("test"): str}, description="Test input"), } ) - def test(self, value: dict) -> BaseArtifact: - return TextArtifact(f"ack {value['values']['test']}") + def test(self, **kwargs) -> BaseArtifact: + return TextArtifact(f"ack {kwargs['test']}") @activity( config={ @@ -33,8 +33,8 @@ def test(self, value: dict) -> BaseArtifact: "schema": Schema({Literal("test"): str}, description="Test input"), } ) - def test_error(self, value: dict) -> BaseArtifact: - return ErrorArtifact(f"error {value['values']['test']}") + def test_error(self, params: dict) -> BaseArtifact: + return ErrorArtifact(f"error {params['values']['test']}") @activity( config={ @@ -42,8 +42,8 @@ def test_error(self, value: dict) -> BaseArtifact: "schema": Schema({Literal("test"): str}, description="Test input"), } ) - def test_exception(self, value: dict) -> BaseArtifact: - raise Exception(f"error {value['values']['test']}") + def test_exception(self, params: dict) -> BaseArtifact: + raise Exception(f"error {params['values']['test']}") @activity( config={ @@ -51,11 +51,11 @@ def test_exception(self, value: dict) -> BaseArtifact: "schema": Schema({Literal("test"): str}, description="Test input"), } ) - def test_str_output(self, value: dict) -> str: - return f"ack {value['values']['test']}" + def test_str_output(self, params: dict) -> str: + return f"ack {params['values']['test']}" @activity(config={"description": "test description"}) - def test_no_schema(self, value: dict) -> str: + def test_no_schema(self) -> str: return "no schema" @activity( @@ -68,14 +68,14 @@ def test_callable_schema(self) -> TextArtifact: return TextArtifact("ack") @activity(config={"description": "test description"}) - def test_list_output(self, value: dict) -> ListArtifact: + def test_list_output(self) -> ListArtifact: return ListArtifact([TextArtifact("foo"), TextArtifact("bar")]) @activity( config={"description": "test description", "schema": Schema({Literal("test"): str}, description="Test input")} ) - def test_without_default_memory(self, value: dict) -> str: - return f"ack {value['values']['test']}" + def test_without_default_memory(self, params: dict) -> str: + return f"ack {params['values']['test']}" def foo(self) -> str: return "foo" From 901bd759ad682393f1d9cc8dab3ec4d03a7c16d3 Mon Sep 17 00:00:00 2001 From: matt Date: Wed, 16 Oct 2024 19:09:32 -0500 Subject: [PATCH 2/5] new func --- griptape/utils/decorators.py | 48 +++++++++++++++++++----------------- 1 file changed, 26 insertions(+), 22 deletions(-) diff --git a/griptape/utils/decorators.py b/griptape/utils/decorators.py index f28972795..a99d5701c 100644 --- a/griptape/utils/decorators.py +++ b/griptape/utils/decorators.py @@ -15,7 +15,7 @@ ) -def activity(config: dict) -> Any: # noqa: C901 +def activity(config: dict) -> Any: validated_config = CONFIG_SCHEMA.validate(config) validated_config.update({k: v for k, v in config.items() if k not in validated_config}) @@ -26,27 +26,7 @@ def activity(config: dict) -> Any: # noqa: C901 def decorator(func: Callable) -> Any: @functools.wraps(func) def wrapper(self: Any, params: dict) -> Any: - func_params = inspect.signature(func).parameters.copy() - func_params.pop("self") - - kwarg_var = None - for param in func_params.values(): - if param.kind == inspect.Parameter.VAR_KEYWORD: - kwarg_var = func_params.pop(param.name).name - break - - kwargs = {k: v for k, v in params.get("values", {}).items() if k in func_params or kwarg_var is not None} - - if "params" in func_params or kwarg_var is not None: - kwargs["params"] = params - if "values" in func_params or kwarg_var is not None: - kwargs["values"] = params.get("values") - - for param_name in func_params.keys(): # noqa: SIM118 - if param_name not in kwargs: - kwargs[param_name] = None - - return func(self, **kwargs) + return func(self, **_build_kwargs(func, params)) setattr(wrapper, "name", func.__name__) setattr(wrapper, "config", validated_config) @@ -75,3 +55,27 @@ def lazy_attr(self: Any, value: Any) -> None: return lazy_attr return decorator + + +def _build_kwargs(func: Callable, params: dict) -> dict: + func_params = inspect.signature(func).parameters.copy() + func_params.pop("self") + + kwarg_var = None + for param in func_params.values(): + if param.kind == inspect.Parameter.VAR_KEYWORD: + kwarg_var = func_params.pop(param.name).name + break + + kwargs = {k: v for k, v in params.get("values", {}).items() if k in func_params or kwarg_var is not None} + + if "params" in func_params or kwarg_var is not None: + kwargs["params"] = params + if "values" in func_params or kwarg_var is not None: + kwargs["values"] = params.get("values") + + for param_name in func_params: + if param_name not in kwargs: + kwargs[param_name] = None + + return kwargs From b8741c461fb98ee1324f48dd4ec719d3f26abb7f Mon Sep 17 00:00:00 2001 From: matt Date: Wed, 16 Oct 2024 19:20:40 -0500 Subject: [PATCH 3/5] tests, comments, ruff --- griptape/utils/decorators.py | 7 +++++++ tests/mocks/mock_tool/tool.py | 19 +++++++++++++++++-- tests/unit/tools/test_base_tool.py | 11 +++++++++++ 3 files changed, 35 insertions(+), 2 deletions(-) diff --git a/griptape/utils/decorators.py b/griptape/utils/decorators.py index a99d5701c..3eef6d8d0 100644 --- a/griptape/utils/decorators.py +++ b/griptape/utils/decorators.py @@ -63,17 +63,24 @@ def _build_kwargs(func: Callable, params: dict) -> dict: kwarg_var = None for param in func_params.values(): + # if there is a **kwargs parameter, we can safely + # pass all the params to the function if param.kind == inspect.Parameter.VAR_KEYWORD: kwarg_var = func_params.pop(param.name).name break + # only pass the values that are in the function signature + # or if there is a **kwargs parameter, pass all the values kwargs = {k: v for k, v in params.get("values", {}).items() if k in func_params or kwarg_var is not None} + # add 'params' and 'values' if they are in the signature + # or if there is a **kwargs parameter if "params" in func_params or kwarg_var is not None: kwargs["params"] = params if "values" in func_params or kwarg_var is not None: kwargs["values"] = params.get("values") + # set any missing parameters to None for param_name in func_params: if param_name not in kwargs: kwargs[param_name] = None diff --git a/tests/mocks/mock_tool/tool.py b/tests/mocks/mock_tool/tool.py index 4675c761e..1162ea1ee 100644 --- a/tests/mocks/mock_tool/tool.py +++ b/tests/mocks/mock_tool/tool.py @@ -24,8 +24,8 @@ class MockTool(BaseTool): "schema": Schema({Literal("test"): str}, description="Test input"), } ) - def test(self, **kwargs) -> BaseArtifact: - return TextArtifact(f"ack {kwargs['test']}") + def test(self, test: str) -> BaseArtifact: + return TextArtifact(f"ack {test}") @activity( config={ @@ -77,6 +77,21 @@ def test_list_output(self) -> ListArtifact: def test_without_default_memory(self, params: dict) -> str: return f"ack {params['values']['test']}" + @activity( + config={"description": "test description", "schema": Schema({Literal("test"): str}, description="Test input")} + ) + def test_with_none_kwarg(self, not_test: None) -> bool: + return not_test is None + + @activity( + config={ + "description": "test description", + "schema": Schema({Literal("test_kwarg"): str}, description="Test input"), + } + ) + def test_with_kwargs_var(self, **kwargs) -> bool: + return "test_kwarg" in kwargs + def foo(self) -> str: return "foo" diff --git a/tests/unit/tools/test_base_tool.py b/tests/unit/tools/test_base_tool.py index 9f28acf02..36115e8a2 100644 --- a/tests/unit/tools/test_base_tool.py +++ b/tests/unit/tools/test_base_tool.py @@ -308,3 +308,14 @@ def test_from_dict(self, tool): assert isinstance(deserialized_tool, BaseTool) assert deserialized_tool.execute(tool.test_list_output, ActionsSubtask("foo"), action).to_text() == "foo\n\nbar" + + def test_method_kwarg_injection(self, tool): + tool = MockTool() + + assert tool.test_with_none_kwarg() is True + + def test_method_kwargs_var_injection(self, tool): + tool = MockTool() + + assert tool.test_with_kwargs_var(test_kwarg="foo") is True + assert tool.test_with_kwargs_var() is False From 71059842859aad69ac4d7775978755f66284955a Mon Sep 17 00:00:00 2001 From: matt Date: Wed, 16 Oct 2024 19:51:00 -0500 Subject: [PATCH 4/5] update tests --- tests/mocks/mock_tool/tool.py | 15 ----------- tests/mocks/mock_tool_kwargs/__init__.py | 0 tests/mocks/mock_tool_kwargs/requirements.txt | 0 tests/mocks/mock_tool_kwargs/tool.py | 25 +++++++++++++++++++ tests/unit/tools/test_base_tool.py | 12 +++------ 5 files changed, 29 insertions(+), 23 deletions(-) create mode 100644 tests/mocks/mock_tool_kwargs/__init__.py create mode 100644 tests/mocks/mock_tool_kwargs/requirements.txt create mode 100644 tests/mocks/mock_tool_kwargs/tool.py diff --git a/tests/mocks/mock_tool/tool.py b/tests/mocks/mock_tool/tool.py index 1162ea1ee..c7bacc3b2 100644 --- a/tests/mocks/mock_tool/tool.py +++ b/tests/mocks/mock_tool/tool.py @@ -77,21 +77,6 @@ def test_list_output(self) -> ListArtifact: def test_without_default_memory(self, params: dict) -> str: return f"ack {params['values']['test']}" - @activity( - config={"description": "test description", "schema": Schema({Literal("test"): str}, description="Test input")} - ) - def test_with_none_kwarg(self, not_test: None) -> bool: - return not_test is None - - @activity( - config={ - "description": "test description", - "schema": Schema({Literal("test_kwarg"): str}, description="Test input"), - } - ) - def test_with_kwargs_var(self, **kwargs) -> bool: - return "test_kwarg" in kwargs - def foo(self) -> str: return "foo" diff --git a/tests/mocks/mock_tool_kwargs/__init__.py b/tests/mocks/mock_tool_kwargs/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/mocks/mock_tool_kwargs/requirements.txt b/tests/mocks/mock_tool_kwargs/requirements.txt new file mode 100644 index 000000000..e69de29bb diff --git a/tests/mocks/mock_tool_kwargs/tool.py b/tests/mocks/mock_tool_kwargs/tool.py new file mode 100644 index 000000000..cd95f9c75 --- /dev/null +++ b/tests/mocks/mock_tool_kwargs/tool.py @@ -0,0 +1,25 @@ +from attrs import define +from schema import Literal, Schema + +from griptape.tools import BaseTool +from griptape.utils.decorators import activity + + +@define +class MockToolKwargs(BaseTool): + @activity( + config={ + "description": "test description", + "schema": Schema({Literal("test_kwarg"): str}, description="Test input"), + } + ) + def test_with_kwargs(self, params: dict, test_kwarg: str, test_kwarg_none: None, **kwargs) -> str: + if test_kwarg_none is not None: + raise ValueError("test_kwarg_none should be None") + if "test_kwarg_kwargs" not in kwargs: + raise ValueError("test_kwarg_kwargs not in kwargs") + if "values" not in kwargs: + raise ValueError("values not in params") + if "test_kwarg" not in params["values"]: + raise ValueError("test_kwarg not in params") + return f"ack {test_kwarg}" diff --git a/tests/unit/tools/test_base_tool.py b/tests/unit/tools/test_base_tool.py index 36115e8a2..5ac3849d5 100644 --- a/tests/unit/tools/test_base_tool.py +++ b/tests/unit/tools/test_base_tool.py @@ -8,6 +8,7 @@ from griptape.tasks import ActionsSubtask, ToolkitTask from griptape.tools import BaseTool from tests.mocks.mock_tool.tool import MockTool +from tests.mocks.mock_tool_kwargs.tool import MockToolKwargs from tests.utils import defaults @@ -309,13 +310,8 @@ def test_from_dict(self, tool): assert deserialized_tool.execute(tool.test_list_output, ActionsSubtask("foo"), action).to_text() == "foo\n\nbar" - def test_method_kwarg_injection(self, tool): - tool = MockTool() - - assert tool.test_with_none_kwarg() is True - def test_method_kwargs_var_injection(self, tool): - tool = MockTool() + tool = MockToolKwargs() - assert tool.test_with_kwargs_var(test_kwarg="foo") is True - assert tool.test_with_kwargs_var() is False + params = {"values": {"test_kwarg": "foo", "test_kwarg_kwargs": "bar"}} + assert tool.test_with_kwargs(params) == "ack foo" From 35c5f48bae479b98392d5123fe07eacbad57eab4 Mon Sep 17 00:00:00 2001 From: matt Date: Thu, 17 Oct 2024 11:56:45 -0500 Subject: [PATCH 5/5] log --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 37afa4b41..212712194 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -39,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `Chat` output now uses `Rich.print` by default. - `Chat.output_fn`'s now takes an optional kwarg parameter, `stream`. - Implemented `SerializableMixin` in `Structure`, `BaseTask`, `BaseTool`, and `TaskMemory` +- `@activity` decorated functions can now accept kwargs that are defined in the activity schema. ### Fixed