-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3c4b414
commit f261cd6
Showing
17 changed files
with
680 additions
and
5 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
18 changes: 18 additions & 0 deletions
18
docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_1.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
|
||
from griptape.drivers import GriptapeCloudToolDriver | ||
from griptape.structures import Agent | ||
from griptape.tools import CalculatorTool | ||
|
||
tool = CalculatorTool() | ||
driver = GriptapeCloudToolDriver( | ||
api_key=os.environ["GT_CLOUD_API_KEY"], | ||
tool_id=os.environ["GT_CLOUD_TOOL_ID"], | ||
) | ||
driver.initialize_tool(tool) | ||
|
||
agent = Agent( | ||
tools=[tool], | ||
) | ||
|
||
agent.run("What is 2 ^ 2 + 12 and 5 * 145369?") |
41 changes: 41 additions & 0 deletions
41
docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_2.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import os | ||
|
||
import schema | ||
|
||
from griptape.artifacts.text_artifact import TextArtifact | ||
from griptape.drivers import GriptapeCloudToolDriver | ||
from griptape.structures import Agent | ||
from griptape.tools import CalculatorTool | ||
from griptape.utils.decorators import activity | ||
|
||
|
||
class ExtendedCalculator(CalculatorTool): | ||
@activity( | ||
config={ | ||
"description": "Convert between fahrenheit and celsius", | ||
"schema": schema.Schema( | ||
{ | ||
"temperature": float, | ||
"unit": schema.Or("fahrenheit", "celsius"), # pyright: ignore[reportArgumentType] | ||
} | ||
), | ||
} | ||
) | ||
def convert(self, temperature: float, unit: str) -> TextArtifact: | ||
converted = (temperature - 32) * 5 / 9 if unit == "fahrenheit" else temperature * 9 / 5 + 32 | ||
|
||
return TextArtifact(converted) | ||
|
||
|
||
agent = Agent( | ||
tools=[ | ||
CalculatorTool( | ||
tool_driver=GriptapeCloudToolDriver( | ||
api_key=os.environ["GT_CLOUD_API_KEY"], | ||
tool_id=os.environ["GT_CLOUD_TOOL_ID"], | ||
) | ||
) | ||
], | ||
) | ||
|
||
agent.run("What is 2 ^ 2 + 12 and 5 * 145369 and 30 degrees F to C?") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
|
||
from griptape.drivers import GriptapeCloudToolDriver | ||
from griptape.structures import Agent | ||
from griptape.tools import CalculatorTool | ||
|
||
agent = Agent( | ||
tools=[ | ||
CalculatorTool( | ||
tool_driver=GriptapeCloudToolDriver( | ||
api_key=os.environ["GT_CLOUD_API_KEY"], | ||
tool_id=os.environ["GT_CLOUD_TOOL_ID"], | ||
) | ||
) | ||
], | ||
) | ||
|
||
agent.run("What is 2 ^ 2 + 12 and 5 * 145369?") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
--- | ||
search: | ||
boost: 2 | ||
--- | ||
|
||
## Overview | ||
|
||
Tool Drivers can be used to augment [Tools](../../griptape-tools/index.md) with additional functionality. | ||
|
||
You can use Tool Drivers with Tools: | ||
|
||
```python | ||
--8<-- "docs/griptape-framework/drivers/src/tool_driver_1.py" | ||
``` | ||
|
||
In this example, the `CalculatorTool`'s schema is loaded from Griptape Cloud. Additionally, when the `CalculatorTool` is executed, it is run in the Cloud rather than locally. | ||
|
||
## Tool Drivers | ||
|
||
### Griptape Cloud | ||
|
||
The [GriptapeCloudToolDriver](../../reference/griptape/drivers/tool/griptape_cloud_tool_driver.md) uses Griptape Cloud's hosted Tools service to modify the Tool with the hosted Tool's schema and execution logic. | ||
|
||
```python | ||
--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_1.py" | ||
``` | ||
|
||
The Driver will only overwrite activities present in the hosted Tool, meaning you can mix and match local and hosted activities. | ||
|
||
```python | ||
--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_2.py" | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import TYPE_CHECKING | ||
|
||
from attrs import define | ||
|
||
if TYPE_CHECKING: | ||
from griptape.tools.base_tool import BaseTool | ||
|
||
|
||
@define | ||
class BaseToolDriver(ABC): | ||
@abstractmethod | ||
def initialize_tool(self, tool: BaseTool) -> None: | ||
"""Performs some initialization operations on the Tool. This method is intended to mutate the Tool object. | ||
Args: | ||
tool: Tool to initialize. | ||
""" | ||
|
||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,140 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
from types import MethodType | ||
from typing import TYPE_CHECKING, Any, Callable | ||
from urllib.parse import urljoin | ||
|
||
import requests | ||
from attrs import Factory, define, field | ||
from schema import Literal, Optional, Schema | ||
|
||
from griptape.artifacts import BaseArtifact, TextArtifact | ||
from griptape.drivers.tool.base_tool_driver import BaseToolDriver | ||
from griptape.utils.decorators import activity | ||
|
||
if TYPE_CHECKING: | ||
from griptape.tools.base_tool import BaseTool | ||
|
||
|
||
@define | ||
class GriptapeCloudToolDriver(BaseToolDriver): | ||
"""Driver for interacting with tools hosted on the Griptape Cloud. | ||
Attributes: | ||
base_url: Base URL of the Griptape Cloud. | ||
api_key: API key for the Griptape Cloud. | ||
tool_id: ID of the tool to interact with. | ||
headers: Headers to use for requests. | ||
raise_on_same_activity: Whether to raise an error if an activity with the same name already exists on the tool. | ||
""" | ||
|
||
base_url: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai"))) | ||
api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"])) | ||
tool_id: str = field(kw_only=True) | ||
headers: dict = field( | ||
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), | ||
init=False, | ||
) | ||
raise_on_same_activity: bool = field(default=True, kw_only=True) | ||
|
||
def initialize_tool(self, tool: BaseTool) -> None: | ||
schema = self._get_schema() | ||
tool_name, activity_schemas = self._parse_schema(schema) | ||
|
||
if tool.name == tool.__class__.__name__: | ||
tool.name = tool_name | ||
|
||
for activity_name, (description, activity_schema) in activity_schemas.items(): | ||
activity_handler = self._create_activity_handler(activity_name, description, activity_schema) | ||
|
||
if hasattr(tool, activity_name) and self.raise_on_same_activity: | ||
raise ValueError(f"Activity {activity_name} already exists on the tool.") | ||
setattr(tool, activity_name, MethodType(activity_handler, self)) | ||
|
||
def _get_schema(self) -> dict: | ||
url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/openapi") | ||
response = requests.get(url, headers=self.headers) | ||
|
||
response.raise_for_status() | ||
|
||
return response.json() | ||
|
||
def _parse_schema(self, schema: dict) -> tuple[str, dict[str, tuple[str, Schema]]]: | ||
"""Parses an openapi schema into a dictionary of activity names and their respective descriptions + schemas.""" | ||
activities = {} | ||
|
||
name = schema.get("info", {}).get("title") | ||
|
||
for path, path_info in schema.get("paths", {}).items(): | ||
if not path.startswith("/activities"): | ||
continue | ||
for method, method_info in path_info.items(): | ||
if "post" in method.lower(): | ||
activity_name = method_info["operationId"] | ||
description = method_info.get("description", "") | ||
|
||
activity_schema = self.__extract_schema_from_ref( | ||
schema, | ||
method_info.get("requestBody", {}) | ||
.get("content", {}) | ||
.get("application/json", {}) | ||
.get("schema", {}), | ||
) | ||
|
||
activities[activity_name] = (description, activity_schema) | ||
|
||
return name, activities | ||
|
||
def __extract_schema_from_ref(self, schema: dict, schema_ref: dict) -> Schema: | ||
"""Extracts a schema from a $ref if present, resolving it into native schema properties.""" | ||
if "$ref" in schema_ref: | ||
# Resolve the reference and retrieve the schema data | ||
ref_path = schema_ref["$ref"].split("/")[-1] | ||
schema_data = schema["components"]["schemas"].get(ref_path, {}) | ||
else: | ||
# Use the provided schema directly if no $ref is found | ||
schema_data = schema_ref | ||
|
||
# Convert the schema_data dictionary into a Schema with its properties | ||
properties = {} | ||
for prop, prop_info in schema_data.get("properties", {}).items(): | ||
prop_type = prop_info.get("type", "string") | ||
prop_description = prop_info.get("description", "") | ||
schema_prop = Literal(prop, description=prop_description) | ||
is_optional = prop not in schema_data.get("required", []) | ||
|
||
if is_optional: | ||
schema_prop = Optional(schema_prop) | ||
|
||
properties[schema_prop] = self._map_openapi_type_to_python(prop_type) | ||
|
||
return Schema(properties) | ||
|
||
def _map_openapi_type_to_python(self, openapi_type: str) -> type: | ||
"""Maps OpenAPI types to native Python types.""" | ||
type_mapping = {"string": str, "integer": int, "boolean": bool, "number": float, "array": list, "object": dict} | ||
|
||
return type_mapping.get(openapi_type, str) | ||
|
||
def _create_activity_handler(self, activity_name: str, description: str, activity_schema: Schema) -> Callable: | ||
"""Creates an activity handler method for the tool.""" | ||
|
||
@activity(config={"name": activity_name, "description": description, "schema": activity_schema}) | ||
def activity_handler(_: BaseTool, values: dict) -> Any: | ||
return self._run_activity(activity_name, values) | ||
|
||
return activity_handler | ||
|
||
def _run_activity(self, activity_name: str, params: dict) -> BaseArtifact: | ||
"""Runs an activity on the tool with the provided parameters.""" | ||
url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/activities/{activity_name}") | ||
|
||
response = requests.post(url, json=params, headers=self.headers) | ||
|
||
response.raise_for_status() | ||
|
||
try: | ||
return BaseArtifact.from_dict(response.json()) | ||
except ValueError: | ||
return TextArtifact(response.text) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.