Skip to content

Commit

Permalink
Add GriptapeCloudToolDriver (#1367)
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter authored Nov 20, 2024
1 parent 3c4b414 commit 74a2937
Show file tree
Hide file tree
Showing 17 changed files with 673 additions and 5 deletions.
1 change: 1 addition & 0 deletions .github/workflows/docs-integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ jobs:
GT_CLOUD_STRUCTURE_ID: ${{ vars.INTEG_GT_CLOUD_STRUCTURE_ID }}
GT_CLOUD_STRUCTURE_RUN_ID: ${{ vars.INTEG_GT_CLOUD_STRUCTURE_RUN_ID }}
GT_CLOUD_THREAD_ID: ${{ vars.INTEG_GT_CLOUD_THREAD_ID }}
GT_CLOUD_TOOL_ID: ${{ vars.INTEG_GT_CLOUD_TOOL_ID }}
HUGGINGFACE_HUB_ACCESS_TOKEN: ${{ secrets.INTEG_HUGGINGFACE_HUB_ACCESS_TOKEN }}
LEONARDO_API_KEY: ${{ secrets.INTEG_LEONARDO_API_KEY }}
LEONARDO_MODEL_ID: ${{ secrets.INTEG_LEONARDO_MODEL_ID }}
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `AssistantTask` for running Assistants in Structures.
- `GriptapeCloudAssistantDriver` for interacting with Griptape Cloud's Assistant API.
- `OpenAiAssistantDriver` for interacting with OpenAI's Assistant API.
- New class of Drivers, Tool Drivers, for augmenting Tools with additional functionality.
- `GriptapeCloudToolDriver` for interacting with Griptape Cloud's Tool API.
- `BaseTool.tool_driver` for setting a Tool Driver on a Tool. The Tool will call `tool_driver.initialize_tool` after Tool initialization.

### Changed

Expand Down
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ Drivers facilitate interactions with external resources and services:
- 📈 **Observability Drivers** send trace and event data to observability platforms.
- 📜 **Ruleset Drivers** load and apply rulesets from external sources.
- 🗂️ **File Manager Drivers** handle file operations on local and remote storage.
- 🔨 **Tool Drivers** augment Tools with additional functionality.

### 🚂 Engines

Expand Down Expand Up @@ -217,7 +218,7 @@ Run `make check` to run all code checks locally.

### Griptape Extensions

Griptape's extensibility allows anyone to develop and distribute functionality independently.
Griptape's extensibility allows anyone to develop and distribute functionality independently.
All new integrations, including Tools, Drivers, Tasks, etc., should initially be developed as extensions and then can be upstreamed into Griptape core if discussed and approved.

The [Griptape Extension Template](https://github.com/griptape-ai/griptape-extension-template) provides the recommended structure, step-by-step instructions, basic automation, and usage examples for new integrations.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os

from griptape.drivers import GriptapeCloudToolDriver
from griptape.structures import Agent
from griptape.tools import BaseTool


class RandomNumberGeneratorTool(BaseTool): ...


tool = RandomNumberGeneratorTool()
driver = GriptapeCloudToolDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
tool_id=os.environ["GT_CLOUD_TOOL_ID"],
)
driver.initialize_tool(tool)

agent = Agent(
tools=[tool],
)

agent.run("Generate a random number to 5 decimal places.")
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import os
import random

import schema

from griptape.artifacts.text_artifact import TextArtifact
from griptape.drivers import GriptapeCloudToolDriver
from griptape.structures import Agent
from griptape.tools.base_tool import BaseTool
from griptape.utils.decorators import activity


class ExtendedRandomNumberGeneratorTool(BaseTool):
@activity(
config={
"description": "Generates a random number in a range.",
"schema": schema.Schema(
{
"min": schema.Or(float, int),
"max": schema.Or(float, int),
}
),
}
)
def rand_range(self, values: dict) -> TextArtifact:
return TextArtifact(random.uniform(values["min"], values["max"]))


agent = Agent(
tools=[
ExtendedRandomNumberGeneratorTool(
tool_driver=GriptapeCloudToolDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
tool_id=os.environ["GT_CLOUD_TOOL_ID"],
)
)
],
)

agent.run("Generate a random number between 1 and 100.")
22 changes: 22 additions & 0 deletions docs/griptape-framework/drivers/src/tool_driver_1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import os

from griptape.drivers import GriptapeCloudToolDriver
from griptape.structures import Agent
from griptape.tools import BaseTool


class RandomNumberGeneratorTool(BaseTool): ...


agent = Agent(
tools=[
RandomNumberGeneratorTool(
tool_driver=GriptapeCloudToolDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
tool_id=os.environ["GT_CLOUD_TOOL_ID"],
)
)
],
)

agent.run("Generate a random number to 5 decimal places.")
32 changes: 32 additions & 0 deletions docs/griptape-framework/drivers/tool-drivers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
---
search:
boost: 2
---

## Overview

Tool Drivers can be used to augment [Tools](../../griptape-tools/index.md) with additional functionality.

You can use Tool Drivers with Tools:

```python
--8<-- "docs/griptape-framework/drivers/src/tool_driver_1.py"
```

In this example, the `RandomNumberGeneratorTool`'s schema is loaded from Griptape Cloud. Additionally, when the `RandomNumberGeneratorTool` is executed, it is run in the Cloud rather than locally.

## Tool Drivers

### Griptape Cloud

The [GriptapeCloudToolDriver](../../reference/griptape/drivers/tool/griptape_cloud_tool_driver.md) uses Griptape Cloud's hosted Tools service to modify the Tool with the hosted Tool's schema and execution logic.

```python
--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_1.py"
```

The Driver will only overwrite activities present in the hosted Tool, meaning you can mix and match local and hosted activities.

```python
--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_2.py"
```
4 changes: 4 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@
from .assistant.base_assistant_driver import BaseAssistantDriver
from .assistant.griptape_cloud_assistant_driver import GriptapeCloudAssistantDriver
from .assistant.openai_assistant_driver import OpenAiAssistantDriver
from .tool.base_tool_driver import BaseToolDriver
from .tool.griptape_cloud_tool_driver import GriptapeCloudToolDriver

__all__ = [
"BasePromptDriver",
Expand Down Expand Up @@ -240,4 +242,6 @@
"BaseAssistantDriver",
"GriptapeCloudAssistantDriver",
"OpenAiAssistantDriver",
"BaseToolDriver",
"GriptapeCloudToolDriver",
]
Empty file.
23 changes: 23 additions & 0 deletions griptape/drivers/tool/base_tool_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import TYPE_CHECKING

from attrs import define

if TYPE_CHECKING:
from griptape.tools.base_tool import BaseTool


@define
class BaseToolDriver(ABC):
@abstractmethod
def initialize_tool(self, tool: BaseTool) -> None:
"""Performs some initialization operations on the Tool. This method is intended to mutate the Tool object.
Args:
tool: Tool to initialize.
"""

...
136 changes: 136 additions & 0 deletions griptape/drivers/tool/griptape_cloud_tool_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from __future__ import annotations

import os
from types import MethodType
from typing import TYPE_CHECKING, Any, Callable
from urllib.parse import urljoin

import requests
from attrs import Factory, define, field
from schema import Literal, Optional, Schema

from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.drivers.tool.base_tool_driver import BaseToolDriver
from griptape.utils.decorators import activity

if TYPE_CHECKING:
from griptape.tools.base_tool import BaseTool


@define
class GriptapeCloudToolDriver(BaseToolDriver):
"""Driver for interacting with tools hosted on the Griptape Cloud.
Attributes:
base_url: Base URL of the Griptape Cloud.
api_key: API key for the Griptape Cloud.
tool_id: ID of the tool to interact with.
headers: Headers to use for requests.
"""

base_url: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")))
api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
tool_id: str = field(kw_only=True)
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
init=False,
)

def initialize_tool(self, tool: BaseTool) -> None:
schema = self._get_schema()
tool_name, activity_schemas = self._parse_schema(schema)

if tool.name == tool.__class__.__name__:
tool.name = tool_name

for activity_name, (description, activity_schema) in activity_schemas.items():
activity_handler = self._create_activity_handler(activity_name, description, activity_schema)

setattr(tool, activity_name, MethodType(activity_handler, self))

def _get_schema(self) -> dict:
url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/openapi")
response = requests.get(url, headers=self.headers)

response.raise_for_status()

return response.json()

def _parse_schema(self, schema: dict) -> tuple[str, dict[str, tuple[str, Schema]]]:
"""Parses an openapi schema into a dictionary of activity names and their respective descriptions + schemas."""
activities = {}

name = schema.get("info", {}).get("title")

for path, path_info in schema.get("paths", {}).items():
if not path.startswith("/activities"):
continue
for method, method_info in path_info.items():
if "post" in method.lower():
activity_name = method_info["operationId"]
description = method_info.get("description", "")

activity_schema = self.__extract_schema_from_ref(
schema,
method_info.get("requestBody", {})
.get("content", {})
.get("application/json", {})
.get("schema", {}),
)

activities[activity_name] = (description, activity_schema)

return name, activities

def __extract_schema_from_ref(self, schema: dict, schema_ref: dict) -> Schema:
"""Extracts a schema from a $ref if present, resolving it into native schema properties."""
if "$ref" in schema_ref:
# Resolve the reference and retrieve the schema data
ref_path = schema_ref["$ref"].split("/")[-1]
schema_data = schema["components"]["schemas"].get(ref_path, {})
else:
# Use the provided schema directly if no $ref is found
schema_data = schema_ref

# Convert the schema_data dictionary into a Schema with its properties
properties = {}
for prop, prop_info in schema_data.get("properties", {}).items():
prop_type = prop_info.get("type", "string")
prop_description = prop_info.get("description", "")
schema_prop = Literal(prop, description=prop_description)
is_optional = prop not in schema_data.get("required", [])

if is_optional:
schema_prop = Optional(schema_prop)

properties[schema_prop] = self._map_openapi_type_to_python(prop_type)

return Schema(properties)

def _map_openapi_type_to_python(self, openapi_type: str) -> type:
"""Maps OpenAPI types to native Python types."""
type_mapping = {"string": str, "integer": int, "boolean": bool, "number": float, "array": list, "object": dict}

return type_mapping.get(openapi_type, str)

def _create_activity_handler(self, activity_name: str, description: str, activity_schema: Schema) -> Callable:
"""Creates an activity handler method for the tool."""

@activity(config={"name": activity_name, "description": description, "schema": activity_schema})
def activity_handler(_: BaseTool, values: dict) -> Any:
return self._run_activity(activity_name, values)

return activity_handler

def _run_activity(self, activity_name: str, params: dict) -> BaseArtifact:
"""Runs an activity on the tool with the provided parameters."""
url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/activities/{activity_name}")

response = requests.post(url, json=params, headers=self.headers)

response.raise_for_status()

try:
return BaseArtifact.from_dict(response.json())
except ValueError:
return TextArtifact(response.text)
2 changes: 2 additions & 0 deletions griptape/schemas/base_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
BasePromptDriver,
BaseRulesetDriver,
BaseTextToSpeechDriver,
BaseToolDriver,
BaseVectorStoreDriver,
)
from griptape.events import EventListener
Expand All @@ -192,6 +193,7 @@ def _resolve_types(cls, attrs_cls: type) -> None:
"BaseConversationMemoryDriver": BaseConversationMemoryDriver,
"BaseRulesetDriver": BaseRulesetDriver,
"BaseImageGenerationDriver": BaseImageGenerationDriver,
"BaseToolDriver": BaseToolDriver,
"BaseArtifact": BaseArtifact,
"PromptStack": PromptStack,
"EventListener": EventListener,
Expand Down
5 changes: 5 additions & 0 deletions griptape/tools/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

if TYPE_CHECKING:
from griptape.common import ToolAction
from griptape.drivers import BaseToolDriver
from griptape.memory import TaskMemory
from griptape.tasks import ActionsSubtask

Expand Down Expand Up @@ -56,6 +57,7 @@ class BaseTool(ActivityMixin, SerializableMixin, RunnableMixin["BaseTool"], ABC)
dependencies_install_directory: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
verbose: bool = field(default=False, kw_only=True, metadata={"serializable": True})
off_prompt: bool = field(default=False, kw_only=True, metadata={"serializable": True})
tool_driver: BaseToolDriver = field(default=None, kw_only=True)

def __attrs_post_init__(self) -> None:
if (
Expand All @@ -65,6 +67,9 @@ def __attrs_post_init__(self) -> None:
):
self.install_dependencies(os.environ.copy())

if self.tool_driver is not None:
self.tool_driver.initialize_tool(self)

@output_memory.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_output_memory(self, _: Attribute, output_memory: dict[str, Optional[list[TaskMemory]]]) -> None:
if output_memory:
Expand Down
Loading

0 comments on commit 74a2937

Please sign in to comment.