Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Kwarg injection with @activity decorated functions #1265

Merged
merged 5 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `Chat` output now uses `Rich.print` by default.
- `Chat.output_fn`'s now takes an optional kwarg parameter, `stream`.
- Implemented `SerializableMixin` in `Structure`, `BaseTask`, `BaseTool`, and `TaskMemory`
- `@activity` decorated functions can now accept kwargs that are defined in the activity schema.

### Fixed

Expand Down
11 changes: 9 additions & 2 deletions docs/griptape-tools/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,19 @@ Tools give the LLM abilities to invoke outside APIs, reference data sets, and ge

Griptape tools are special Python classes that LLMs can use to accomplish specific goals. Here is an example custom tool for generating a random number:

A tool can have many "activities" as denoted by the `@activity` decorator. Each activity has a description (used to provide context to the LLM), and the input schema that the LLM must follow in order to use the tool.

When a function is decorated with `@activity`, the decorator injects keyword arguments into the function according to the schema. There are also two Griptape-provided keyword arguments: `params: dict` and `values: dict`.

!!! info
If your schema defines any parameters named `params` or `values`, they will be overwritten by the Griptape-provided arguments.

In the following example, all `@activity` decorated functions will result in the same value, but the method signature is defined in different ways.

```python
--8<-- "docs/griptape-tools/src/index_1.py"
```

A tool can have many "activities" as denoted by the `@activity` decorator. Each activity has a description (used to provide context to the LLM), and the input schema that the LLM must follow in order to use the tool.

Output artifacts from all tool activities (except for `InfoArtifact` and `ErrorArtifact`) go to short-term `TaskMemory`. To disable that behavior set the `off_prompt` tool parameter to `False`:

We provide a set of official Griptape Tools for accessing and processing data. You can also [build your own tools](./custom-tools/index.md).
36 changes: 36 additions & 0 deletions docs/griptape-tools/src/index_1.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from __future__ import annotations

import random
import typing

from schema import Literal, Optional, Schema

Expand All @@ -19,5 +22,38 @@ class RandomNumberGenerator(BaseTool):
def generate(self, params: dict) -> TextArtifact:
return TextArtifact(str(round(random.random(), params["values"].get("decimals"))))

@activity(
config={
"description": "Can be used to generate random numbers",
"schema": Schema(
{Optional(Literal("decimals", description="Number of decimals to round the random number to")): int}
),
}
)
def generate_with_decimals(self, decimals: typing.Optional[int]) -> TextArtifact:
return TextArtifact(str(round(random.random(), decimals)))

@activity(
config={
"description": "Can be used to generate random numbers",
"schema": Schema(
{Optional(Literal("decimals", description="Number of decimals to round the random number to")): int}
),
}
)
def generate_with_values(self, values: dict) -> TextArtifact:
return TextArtifact(str(round(random.random(), values.get("decimals"))))

@activity(
config={
"description": "Can be used to generate random numbers",
"schema": Schema(
{Optional(Literal("decimals", description="Number of decimals to round the random number to")): int}
),
}
)
def generate_with_kwargs(self, **kwargs) -> TextArtifact:
return TextArtifact(str(round(random.random(), kwargs.get("decimals"))))


RandomNumberGenerator()
4 changes: 2 additions & 2 deletions griptape/tools/aws_iam/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def get_user_policy(self, params: dict) -> TextArtifact | ErrorArtifact:
return ErrorArtifact(f"error returning policy document: {e}")

@activity(config={"description": "Can be used to list AWS MFA Devices"})
def list_mfa_devices(self, _: dict) -> ListArtifact | ErrorArtifact:
def list_mfa_devices(self) -> ListArtifact | ErrorArtifact:
try:
devices = self.client.list_mfa_devices()
return ListArtifact([TextArtifact(str(d)) for d in devices["MFADevices"]])
Expand Down Expand Up @@ -76,7 +76,7 @@ def list_user_policies(self, params: dict) -> ListArtifact | ErrorArtifact:
return ErrorArtifact(f"error listing iam user policies: {e}")

@activity(config={"description": "Can be used to list AWS IAM users."})
def list_users(self, _: dict) -> ListArtifact | ErrorArtifact:
def list_users(self) -> ListArtifact | ErrorArtifact:
try:
users = self.client.list_users()
return ListArtifact([TextArtifact(str(u)) for u in users["Users"]])
Expand Down
2 changes: 1 addition & 1 deletion griptape/tools/aws_s3/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def get_object_acl(self, params: dict) -> TextArtifact | ErrorArtifact:
return ErrorArtifact(f"error getting object acl: {e}")

@activity(config={"description": "Can be used to list all AWS S3 buckets."})
def list_s3_buckets(self, _: dict) -> ListArtifact | ErrorArtifact:
def list_s3_buckets(self) -> ListArtifact | ErrorArtifact:
try:
buckets = self.client.list_buckets()

Expand Down
2 changes: 1 addition & 1 deletion griptape/tools/date_time/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

class DateTimeTool(BaseTool):
@activity(config={"description": "Can be used to return current date and time."})
def get_current_datetime(self, _: dict) -> BaseArtifact:
def get_current_datetime(self) -> BaseArtifact:
try:
current_datetime = datetime.now()

Expand Down
8 changes: 3 additions & 5 deletions griptape/tools/web_search/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,10 @@ class WebSearchTool(BaseTool):
),
},
)
def search(self, props: dict) -> ListArtifact | ErrorArtifact:
values = props["values"]
query = values["query"]
extra_keys = {k: values[k] for k in values.keys() - {"query"}}
def search(self, values: dict) -> ListArtifact | ErrorArtifact:
query = values.pop("query")

try:
return self.web_search_driver.search(query, **extra_keys)
return self.web_search_driver.search(query, **values)
except Exception as e:
return ErrorArtifact(f"Error searching '{query}' with {self.web_search_driver.__class__.__name__}: {e}")
36 changes: 34 additions & 2 deletions griptape/utils/decorators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import functools
import inspect
from typing import Any, Callable, Optional

import schema
Expand All @@ -24,8 +25,8 @@ def activity(config: dict) -> Any:

def decorator(func: Callable) -> Any:
@functools.wraps(func)
def wrapper(self: Any, *args, **kwargs) -> Any:
return func(self, *args, **kwargs)
def wrapper(self: Any, params: dict) -> Any:
return func(self, **_build_kwargs(func, params))

setattr(wrapper, "name", func.__name__)
setattr(wrapper, "config", validated_config)
Expand Down Expand Up @@ -54,3 +55,34 @@ def lazy_attr(self: Any, value: Any) -> None:
return lazy_attr

return decorator


def _build_kwargs(func: Callable, params: dict) -> dict:
func_params = inspect.signature(func).parameters.copy()
func_params.pop("self")

kwarg_var = None
for param in func_params.values():
# if there is a **kwargs parameter, we can safely
# pass all the params to the function
if param.kind == inspect.Parameter.VAR_KEYWORD:
kwarg_var = func_params.pop(param.name).name
break

# only pass the values that are in the function signature
# or if there is a **kwargs parameter, pass all the values
kwargs = {k: v for k, v in params.get("values", {}).items() if k in func_params or kwarg_var is not None}

# add 'params' and 'values' if they are in the signature
# or if there is a **kwargs parameter
if "params" in func_params or kwarg_var is not None:
kwargs["params"] = params
if "values" in func_params or kwarg_var is not None:
kwargs["values"] = params.get("values")

# set any missing parameters to None
for param_name in func_params:
if param_name not in kwargs:
kwargs[param_name] = None

return kwargs
24 changes: 12 additions & 12 deletions tests/mocks/mock_tool/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,38 @@ class MockTool(BaseTool):
"schema": Schema({Literal("test"): str}, description="Test input"),
}
)
def test(self, value: dict) -> BaseArtifact:
return TextArtifact(f"ack {value['values']['test']}")
def test(self, test: str) -> BaseArtifact:
return TextArtifact(f"ack {test}")

@activity(
config={
"description": "test description: {{ _self.foo() }}",
"schema": Schema({Literal("test"): str}, description="Test input"),
}
)
def test_error(self, value: dict) -> BaseArtifact:
return ErrorArtifact(f"error {value['values']['test']}")
def test_error(self, params: dict) -> BaseArtifact:
return ErrorArtifact(f"error {params['values']['test']}")

@activity(
config={
"description": "test description: {{ _self.foo() }}",
"schema": Schema({Literal("test"): str}, description="Test input"),
}
)
def test_exception(self, value: dict) -> BaseArtifact:
raise Exception(f"error {value['values']['test']}")
def test_exception(self, params: dict) -> BaseArtifact:
raise Exception(f"error {params['values']['test']}")

@activity(
config={
"description": "test description: {{ _self.foo() }}",
"schema": Schema({Literal("test"): str}, description="Test input"),
}
)
def test_str_output(self, value: dict) -> str:
return f"ack {value['values']['test']}"
def test_str_output(self, params: dict) -> str:
return f"ack {params['values']['test']}"

@activity(config={"description": "test description"})
def test_no_schema(self, value: dict) -> str:
def test_no_schema(self) -> str:
return "no schema"

@activity(
Expand All @@ -68,14 +68,14 @@ def test_callable_schema(self) -> TextArtifact:
return TextArtifact("ack")

@activity(config={"description": "test description"})
def test_list_output(self, value: dict) -> ListArtifact:
def test_list_output(self) -> ListArtifact:
return ListArtifact([TextArtifact("foo"), TextArtifact("bar")])

@activity(
config={"description": "test description", "schema": Schema({Literal("test"): str}, description="Test input")}
)
def test_without_default_memory(self, value: dict) -> str:
return f"ack {value['values']['test']}"
def test_without_default_memory(self, params: dict) -> str:
return f"ack {params['values']['test']}"

def foo(self) -> str:
return "foo"
Expand Down
Empty file.
Empty file.
25 changes: 25 additions & 0 deletions tests/mocks/mock_tool_kwargs/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from attrs import define
from schema import Literal, Schema

from griptape.tools import BaseTool
from griptape.utils.decorators import activity


@define
class MockToolKwargs(BaseTool):
@activity(
config={
"description": "test description",
"schema": Schema({Literal("test_kwarg"): str}, description="Test input"),
}
)
def test_with_kwargs(self, params: dict, test_kwarg: str, test_kwarg_none: None, **kwargs) -> str:
if test_kwarg_none is not None:
raise ValueError("test_kwarg_none should be None")
if "test_kwarg_kwargs" not in kwargs:
raise ValueError("test_kwarg_kwargs not in kwargs")
if "values" not in kwargs:
raise ValueError("values not in params")
if "test_kwarg" not in params["values"]:
raise ValueError("test_kwarg not in params")
return f"ack {test_kwarg}"
7 changes: 7 additions & 0 deletions tests/unit/tools/test_base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from griptape.tasks import ActionsSubtask, ToolkitTask
from griptape.tools import BaseTool
from tests.mocks.mock_tool.tool import MockTool
from tests.mocks.mock_tool_kwargs.tool import MockToolKwargs
from tests.utils import defaults


Expand Down Expand Up @@ -308,3 +309,9 @@ def test_from_dict(self, tool):
assert isinstance(deserialized_tool, BaseTool)

assert deserialized_tool.execute(tool.test_list_output, ActionsSubtask("foo"), action).to_text() == "foo\n\nbar"

def test_method_kwargs_var_injection(self, tool):
tool = MockToolKwargs()

params = {"values": {"test_kwarg": "foo", "test_kwarg_kwargs": "bar"}}
assert tool.test_with_kwargs(params) == "ack foo"
Loading