Skip to content

Commit

Permalink
Add griptape cloud tool tool
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Nov 21, 2024
1 parent c016350 commit ded439b
Show file tree
Hide file tree
Showing 9 changed files with 532 additions and 5 deletions.
4 changes: 1 addition & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `AssistantTask` for running Assistants in Structures.
- `GriptapeCloudAssistantDriver` for interacting with Griptape Cloud's Assistant API.
- `OpenAiAssistantDriver` for interacting with OpenAI's Assistant API.
- New class of Drivers, Tool Drivers, for augmenting Tools with additional functionality.
- `GriptapeCloudToolDriver` for interacting with Griptape Cloud's Tool API.
- `BaseTool.tool_driver` for setting a Tool Driver on a Tool. The Tool will call `tool_driver.initialize_tool` after Tool initialization.
- `GriptapeCloudToolTool` for running Griptape Cloud hosted Tools.

### Changed

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# Griptape Cloud Tool Tool

The [GriptapeCloudToolTool](../../reference/griptape/tools/griptape_cloud_tool/tool.md) integrates with Griptape Cloud's hosted Tools.

**Note:** This tool requires a [Tool](https://cloud.griptape.ai/tools) hosted in Griptape Cloud and an [API Key](https://cloud.griptape.ai/configuration/api-keys) for access.

```python
--8<-- "docs/griptape-tools/official-tools/src/griptape_cloud_tool_tool.py"
```
13 changes: 13 additions & 0 deletions docs/griptape-tools/official-tools/src/griptape_cloud_tool_tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import os

from griptape.structures import Agent
from griptape.tools.griptape_cloud_tool.tool import GriptapeCloudToolTool

agent = Agent(
tools=[
GriptapeCloudToolTool( # Tool is configured as a random number generator
tool_id=os.environ["GT_CLOUD_TOOL_ID"],
)
]
)
agent.run("Generate a number between 1 and 10")
2 changes: 2 additions & 0 deletions griptape/tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .variation_image_generation.tool import VariationImageGenerationTool
from .inpainting_image_generation.tool import InpaintingImageGenerationTool
from .outpainting_image_generation.tool import OutpaintingImageGenerationTool
from .griptape_cloud_tool.tool import GriptapeCloudToolTool
from .structure_run.tool import StructureRunTool
from .image_query.tool import ImageQueryTool
from .rag.tool import RagTool
Expand All @@ -40,6 +41,7 @@
"VariationImageGenerationTool",
"InpaintingImageGenerationTool",
"OutpaintingImageGenerationTool",
"GriptapeCloudToolTool",
"StructureRunTool",
"ImageQueryTool",
"RagTool",
Expand Down
9 changes: 7 additions & 2 deletions griptape/tools/base_griptape_cloud_tool.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import os
from abc import ABC
from typing import Optional

from attrs import Factory, define, field

Expand All @@ -17,8 +19,11 @@ class BaseGriptapeCloudTool(BaseTool, ABC):
headers: Headers for the Griptape Cloud Knowledge Base API.
"""

base_url: str = field(default="https://cloud.griptape.ai", kw_only=True)
api_key: str = field(kw_only=True)
base_url: str = field(
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
kw_only=True,
)
api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")), kw_only=True)
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
kw_only=True,
Expand Down
Empty file.
125 changes: 125 additions & 0 deletions griptape/tools/griptape_cloud_tool/tool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
from __future__ import annotations

from types import MethodType
from typing import Any, Callable
from urllib.parse import urljoin

import requests
from attrs import define, field
from schema import Literal, Optional, Schema

from griptape.artifacts import BaseArtifact, TextArtifact
from griptape.tools.base_griptape_cloud_tool import BaseGriptapeCloudTool
from griptape.utils.decorators import activity


@define()
class GriptapeCloudToolTool(BaseGriptapeCloudTool):
"""Runs a Griptape Cloud hosted Tool.
Attributes:
tool_id: The ID of the tool to run.
"""

tool_id: str = field(kw_only=True)

def __attrs_post_init__(self) -> None:
self._init_activities()

def _init_activities(self) -> None:
schema = self._get_schema()
tool_name, activity_schemas = self._parse_schema(schema)

if self.name == self.__class__.__name__:
self.name = tool_name

for activity_name, (description, activity_schema) in activity_schemas.items():
activity_handler = self._create_activity_handler(activity_name, description, activity_schema)

setattr(self, activity_name, MethodType(activity_handler, self))

def _get_schema(self) -> dict:
response = requests.get(urljoin(self.base_url, f"/api/tools/{self.tool_id}/openapi"), headers=self.headers)

response.raise_for_status()

return response.json()

def _parse_schema(self, schema: dict) -> tuple[str, dict[str, tuple[str, Schema]]]:
"""Parses an openapi schema into a dictionary of activity names and their respective descriptions + schemas."""
activities = {}

name = schema.get("info", {}).get("title")

for path, path_info in schema.get("paths", {}).items():
if not path.startswith("/activities"):
continue
for method, method_info in path_info.items():
if "post" in method.lower():
activity_name = method_info["operationId"]
description = method_info.get("description", "")

activity_schema = self.__extract_schema_from_ref(
schema,
method_info.get("requestBody", {})
.get("content", {})
.get("application/json", {})
.get("schema", {}),
)

activities[activity_name] = (description, activity_schema)

return name, activities

def __extract_schema_from_ref(self, schema: dict, schema_ref: dict) -> Schema:
"""Extracts a schema from a $ref if present, resolving it into native schema properties."""
if "$ref" in schema_ref:
# Resolve the reference and retrieve the schema data
ref_path = schema_ref["$ref"].split("/")[-1]
schema_data = schema["components"]["schemas"].get(ref_path, {})
else:
# Use the provided schema directly if no $ref is found
schema_data = schema_ref

# Convert the schema_data dictionary into a Schema with its properties
properties = {}
for prop, prop_info in schema_data.get("properties", {}).items():
prop_type = prop_info.get("type", "string")
prop_description = prop_info.get("description", "")
schema_prop = Literal(prop, description=prop_description)
is_optional = prop not in schema_data.get("required", [])

if is_optional:
schema_prop = Optional(schema_prop)

properties[schema_prop] = self._map_openapi_type_to_python(prop_type)

return Schema(properties)

def _map_openapi_type_to_python(self, openapi_type: str) -> type:
"""Maps OpenAPI types to native Python types."""
type_mapping = {"string": str, "integer": int, "boolean": bool, "number": float, "array": list, "object": dict}

return type_mapping.get(openapi_type, str)

def _create_activity_handler(self, activity_name: str, description: str, activity_schema: Schema) -> Callable:
"""Creates an activity handler method for the tool."""

@activity(config={"name": activity_name, "description": description, "schema": activity_schema})
def activity_handler(self: GriptapeCloudToolTool, values: dict) -> Any:
return self._run_activity(activity_name, values)

return activity_handler

def _run_activity(self, activity_name: str, params: dict) -> BaseArtifact:
"""Runs an activity on the tool with the provided parameters."""
url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/activities/{activity_name}")

response = requests.post(url, json=params, headers=self.headers)

response.raise_for_status()

try:
return BaseArtifact.from_dict(response.json())
except ValueError:
return TextArtifact(response.text)
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ nav:
- Image Query: "griptape-tools/official-tools/image-query-tool.md"
- Text To Speech: "griptape-tools/official-tools/text-to-speech-tool.md"
- Audio Transcription: "griptape-tools/official-tools/audio-transcription-tool.md"
- Griptape Cloud Tool: "griptape-tools/official-tools/griptape-cloud-tool-tool.md"
- Rag: "griptape-tools/official-tools/rag-tool.md"
- Extraction: "griptape-tools/official-tools/extraction-tool.md"
- Query: "griptape-tools/official-tools/query-tool.md"
Expand Down
Loading

0 comments on commit ded439b

Please sign in to comment.