Skip to content

Commit

Permalink
Add GriptapeCloudToolDriver
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Nov 20, 2024
1 parent 3c4b414 commit abc749b
Show file tree
Hide file tree
Showing 15 changed files with 608 additions and 4 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `AssistantTask` for running Assistants in Structures.
- `GriptapeCloudAssistantDriver` for interacting with Griptape Cloud's Assistant API.
- `OpenAiAssistantDriver` for interacting with OpenAI's Assistant API.
- New class of Drivers, Tool Drivers, for augmenting Tools with additional functionality.
- `GriptapeCloudToolDriver` for interacting with Griptape Cloud's Tool API.
- `BaseTool.tool_driver` for setting a Tool Driver on a Tool. The Tool will call `tool_driver.initialize_tool` after Tool initialization.

### Changed

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Drivers facilitate interactions with external resources and services:
- 📈 **Observability Drivers** send trace and event data to observability platforms.
- 📜 **Ruleset Drivers** load and apply rulesets from external sources.
- 🗂️ **File Manager Drivers** handle file operations on local and remote storage.
- 🔨 **Tool Drivers** augment Tools with additional functionality.

### 🚂 Engines

Expand Down
18 changes: 18 additions & 0 deletions docs/griptape-framework/drivers/src/griptape_cloud_tool_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os

from griptape.drivers import GriptapeCloudToolDriver
from griptape.structures import Agent
from griptape.tools import CalculatorTool

tool = CalculatorTool()
driver = GriptapeCloudToolDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
tool_id=os.environ["GT_CLOUD_TOOL_ID"],
)
driver.initialize_tool(tool)

agent = Agent(
tools=[tool],
)

agent.run("What is 2 ^ 2 + 12 and 5 * 145369?")
18 changes: 18 additions & 0 deletions docs/griptape-framework/drivers/src/tool_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os

from griptape.drivers import GriptapeCloudToolDriver
from griptape.structures import Agent
from griptape.tools import CalculatorTool

agent = Agent(
tools=[
CalculatorTool(
tool_driver=GriptapeCloudToolDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
tool_id=os.environ["GT_CLOUD_TOOL_ID"],
)
)
],
)

agent.run("What is 2 ^ 2 + 12 and 5 * 145369?")
26 changes: 26 additions & 0 deletions docs/griptape-framework/drivers/tool-drivers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
---
search:
boost: 2
---

## Overview

Tool Drivers can be used to augment [Tools](../../griptape-tools/index.md) with additional functionality.

You can use Tool Drivers with Tools:

```python
--8<-- "docs/griptape-framework/drivers/src/tool_driver.py"
```

In this example, the `CalculatorTool`'s schema is loaded from Griptape Cloud. Additionally, when the `CalculatorTool` is executed, it is run in the Cloud rather than locally.

## Tool Drivers

### Griptape Cloud

The [GriptapeCloudToolDriver](../../reference/griptape/drivers/tool/griptape_cloud_tool_driver.md) uses Griptape Cloud's hosted Tools service to modify the Tool with the hosted Tool's schema and execution logic.

```python
--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_tool_driver.py"
```
4 changes: 4 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@
from .assistant.base_assistant_driver import BaseAssistantDriver
from .assistant.griptape_cloud_assistant_driver import GriptapeCloudAssistantDriver
from .assistant.openai_assistant_driver import OpenAiAssistantDriver
from .tool.base_tool_driver import BaseToolDriver
from .tool.griptape_cloud_tool_driver import GriptapeCloudToolDriver

__all__ = [
"BasePromptDriver",
Expand Down Expand Up @@ -240,4 +242,6 @@
"BaseAssistantDriver",
"GriptapeCloudAssistantDriver",
"OpenAiAssistantDriver",
"BaseToolDriver",
"GriptapeCloudToolDriver",
]
Empty file.
23 changes: 23 additions & 0 deletions griptape/drivers/tool/base_tool_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from attrs import define

if TYPE_CHECKING:
from griptape.tools.base_tool import BaseTool


@define
class BaseToolDriver(ABC):
@abstractmethod
def initialize_tool(self, tool: BaseTool) -> None:
"""Performs some initialization operations on the Tool. This method is intended to mutate the Tool object.
Args:
tool: Tool to initialize.
"""

...
127 changes: 127 additions & 0 deletions griptape/drivers/tool/griptape_cloud_tool_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
from __future__ import annotations

import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urljoin

import requests
from attrs import Factory, define, field
from schema import Literal, Optional, Schema

from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.drivers.tool.base_tool_driver import BaseToolDriver
from griptape.utils.decorators import activity

if TYPE_CHECKING:
from griptape.tools.base_tool import BaseTool


@define
class GriptapeCloudToolDriver(BaseToolDriver):
base_url: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")))
api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
tool_id: str = field(kw_only=True)
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
init=False,
)

def initialize_tool(self, tool: BaseTool) -> None:
schema = self._get_schema()
tool_name, activity_schemas = self._parse_schema(schema)

if tool.name == tool.__class__.__name__:
tool.name = tool_name

for activity_name, (description, activity_schema) in activity_schemas.items():
activity_handler = self._create_activity_handler(activity_name, description, activity_schema)

setattr(tool, activity_name, MethodType(activity_handler, self))

def _get_schema(self) -> dict:
url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/openapi")
response = requests.get(url, headers=self.headers)

response.raise_for_status()

return response.json()

def _parse_schema(self, schema: dict) -> tuple[str, dict[str, tuple[str, Schema]]]:
"""Parses an openapi schema into a dictionary of activity names and their respective descriptions + schemas."""
activities = {}

name = schema.get("info", {}).get("title")

for path, path_info in schema.get("paths", {}).items():
if not path.startswith("/activities"):
continue
for method, method_info in path_info.items():
if "post" in method.lower():
activity_name = method_info["operationId"]
description = method_info.get("description", "")

activity_schema = self.__extract_schema_from_ref(
schema,
method_info.get("requestBody", {})
.get("content", {})
.get("application/json", {})
.get("schema", {}),
)

activities[activity_name] = (description, activity_schema)

return name, activities

def __extract_schema_from_ref(self, schema: dict, schema_ref: dict) -> Schema:
"""Extracts a schema from a $ref if present, resolving it into native schema properties."""
if "$ref" in schema_ref:
# Resolve the reference and retrieve the schema data
ref_path = schema_ref["$ref"].split("/")[-1]
schema_data = schema["components"]["schemas"].get(ref_path, {})
else:
# Use the provided schema directly if no $ref is found
schema_data = schema_ref

# Convert the schema_data dictionary into a Schema with its properties
properties = {}
for prop, prop_info in schema_data.get("properties", {}).items():
prop_type = prop_info.get("type", "string")
prop_description = prop_info.get("description", "")
schema_prop = Literal(prop, description=prop_description)
is_optional = prop not in schema_data.get("required", [])

if is_optional:
schema_prop = Optional(schema_prop)

properties[schema_prop] = self._map_openapi_type_to_python(prop_type)

return Schema(properties)

def _map_openapi_type_to_python(self, openapi_type: str) -> type:
"""Maps OpenAPI types to native Python types."""
type_mapping = {"string": str, "integer": int, "boolean": bool, "number": float, "array": list, "object": dict}

return type_mapping.get(openapi_type, str)

def _create_activity_handler(self, activity_name: str, description: str, activity_schema: Schema) -> Callable:
"""Creates an activity handler method for the tool."""

@activity(config={"name": activity_name, "description": description, "schema": activity_schema})
def activity_handler(_: BaseTool, values: dict) -> Any:
return self._run_activity(activity_name, values)

return activity_handler

def _run_activity(self, activity_name: str, params: dict) -> BaseArtifact:
"""Runs an activity on the tool with the provided parameters."""
url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/activities/{activity_name}")

response = requests.post(url, json=params, headers=self.headers)

response.raise_for_status()

try:
return BaseArtifact.from_dict(response.json())
except ValueError:
return TextArtifact(response.text)
2 changes: 2 additions & 0 deletions griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
BasePromptDriver,
BaseRulesetDriver,
BaseTextToSpeechDriver,
BaseToolDriver,
BaseVectorStoreDriver,
)
from griptape.events import EventListener
Expand All @@ -192,6 +193,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
"BaseConversationMemoryDriver": BaseConversationMemoryDriver,
"BaseRulesetDriver": BaseRulesetDriver,
"BaseImageGenerationDriver": BaseImageGenerationDriver,
"BaseToolDriver": BaseToolDriver,
"BaseArtifact": BaseArtifact,
"PromptStack": PromptStack,
"EventListener": EventListener,
Expand Down
5 changes: 5 additions & 0 deletions griptape/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

if TYPE_CHECKING:
from griptape.common import ToolAction
from griptape.drivers import BaseToolDriver
from griptape.memory import TaskMemory
from griptape.tasks import ActionsSubtask

Expand Down Expand Up @@ -56,6 +57,7 @@ class BaseTool(ActivityMixin, SerializableMixin, RunnableMixin["BaseTool"], ABC)
dependencies_install_directory: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
verbose: bool = field(default=False, kw_only=True, metadata={"serializable": True})
off_prompt: bool = field(default=False, kw_only=True, metadata={"serializable": True})
tool_driver: BaseToolDriver = field(default=None, kw_only=True)

def __attrs_post_init__(self) -> None:
if (
Expand All @@ -65,6 +67,9 @@ def __attrs_post_init__(self) -> None:
):
self.install_dependencies(os.environ.copy())

if self.tool_driver is not None:
self.tool_driver.initialize_tool(self)

@output_memory.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_output_memory(self, _: Attribute, output_memory: dict[str, Optional[list[TaskMemory]]]) -> None:
if output_memory:
Expand Down
9 changes: 5 additions & 4 deletions griptape/utils/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,14 @@

import functools
import inspect
from typing import Any, Callable, Optional
from typing import Any, Callable, Optional, OrderedDict, cast

import schema
from schema import Schema

CONFIG_SCHEMA = Schema(
{
schema.Optional("name"): str,
"description": str,
schema.Optional("schema"): lambda data: isinstance(data, (Schema, Callable)),
}
Expand All @@ -28,7 +29,7 @@ def decorator(func: Callable) -> Any:
def wrapper(self: Any, params: dict) -> Any:
return func(self, **_build_kwargs(func, params))

setattr(wrapper, "name", func.__name__)
setattr(wrapper, "name", validated_config.get("name", func.__name__))
setattr(wrapper, "config", validated_config)
setattr(wrapper, "is_activity", True)

Expand Down Expand Up @@ -58,8 +59,8 @@ def lazy_attr(self: Any, value: Any) -> None:


def _build_kwargs(func: Callable, params: dict) -> dict:
func_params = inspect.signature(func).parameters.copy()
func_params.pop("self")
func_params = cast(OrderedDict, inspect.signature(func).parameters.copy())
func_params.popitem(last=False)

kwarg_var = None
for param in func_params.values():
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ nav:
- Observability Drivers: "griptape-framework/drivers/observability-drivers.md"
- Ruleset Drivers: "griptape-framework/drivers/ruleset-drivers.md"
- File Manager Drivers: "griptape-framework/drivers/file-manager-drivers.md"
- Tool Drivers: "griptape-framework/drivers/tool-drivers.md"
- Data:
- Overview: "griptape-framework/data/index.md"
- Artifacts: "griptape-framework/data/artifacts.md"
Expand Down
Empty file.
Loading

0 comments on commit abc749b

Please sign in to comment.