-
Notifications
You must be signed in to change notification settings - Fork 184
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
3c4b414
commit abc749b
Showing
15 changed files
with
608 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
18 changes: 18 additions & 0 deletions
18
docs/griptape-framework/drivers/src/griptape_cloud_tool_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
|
||
from griptape.drivers import GriptapeCloudToolDriver | ||
from griptape.structures import Agent | ||
from griptape.tools import CalculatorTool | ||
|
||
tool = CalculatorTool() | ||
driver = GriptapeCloudToolDriver( | ||
api_key=os.environ["GT_CLOUD_API_KEY"], | ||
tool_id=os.environ["GT_CLOUD_TOOL_ID"], | ||
) | ||
driver.initialize_tool(tool) | ||
|
||
agent = Agent( | ||
tools=[tool], | ||
) | ||
|
||
agent.run("What is 2 ^ 2 + 12 and 5 * 145369?") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
|
||
from griptape.drivers import GriptapeCloudToolDriver | ||
from griptape.structures import Agent | ||
from griptape.tools import CalculatorTool | ||
|
||
agent = Agent( | ||
tools=[ | ||
CalculatorTool( | ||
tool_driver=GriptapeCloudToolDriver( | ||
api_key=os.environ["GT_CLOUD_API_KEY"], | ||
tool_id=os.environ["GT_CLOUD_TOOL_ID"], | ||
) | ||
) | ||
], | ||
) | ||
|
||
agent.run("What is 2 ^ 2 + 12 and 5 * 145369?") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
--- | ||
search: | ||
boost: 2 | ||
--- | ||
|
||
## Overview | ||
|
||
Tool Drivers can be used to augment [Tools](../../griptape-tools/index.md) with additional functionality. | ||
|
||
You can use Tool Drivers with Tools: | ||
|
||
```python | ||
--8<-- "docs/griptape-framework/drivers/src/tool_driver.py" | ||
``` | ||
|
||
In this example, the `CalculatorTool`'s schema is loaded from Griptape Cloud. Additionally, when the `CalculatorTool` is executed, it is run in the Cloud rather than locally. | ||
|
||
## Tool Drivers | ||
|
||
### Griptape Cloud | ||
|
||
The [GriptapeCloudToolDriver](../../reference/griptape/drivers/tool/griptape_cloud_tool_driver.md) uses Griptape Cloud's hosted Tools service to modify the Tool with the hosted Tool's schema and execution logic. | ||
|
||
```python | ||
--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_tool_driver.py" | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from __future__ import annotations | ||
|
||
from abc import ABC, abstractmethod | ||
from typing import TYPE_CHECKING | ||
|
||
from attrs import define | ||
|
||
if TYPE_CHECKING: | ||
from griptape.tools.base_tool import BaseTool | ||
|
||
|
||
@define | ||
class BaseToolDriver(ABC): | ||
@abstractmethod | ||
def initialize_tool(self, tool: BaseTool) -> None: | ||
"""Performs some initialization operations on the Tool. This method is intended to mutate the Tool object. | ||
Args: | ||
tool: Tool to initialize. | ||
""" | ||
|
||
... |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
from types import MethodType | ||
from typing import TYPE_CHECKING, Any, Callable | ||
from urllib.parse import urljoin | ||
|
||
import requests | ||
from attrs import Factory, define, field | ||
from schema import Literal, Optional, Schema | ||
|
||
from griptape.artifacts import BaseArtifact, TextArtifact | ||
from griptape.drivers.tool.base_tool_driver import BaseToolDriver | ||
from griptape.utils.decorators import activity | ||
|
||
if TYPE_CHECKING: | ||
from griptape.tools.base_tool import BaseTool | ||
|
||
|
||
@define | ||
class GriptapeCloudToolDriver(BaseToolDriver): | ||
base_url: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai"))) | ||
api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"])) | ||
tool_id: str = field(kw_only=True) | ||
headers: dict = field( | ||
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), | ||
init=False, | ||
) | ||
|
||
def initialize_tool(self, tool: BaseTool) -> None: | ||
schema = self._get_schema() | ||
tool_name, activity_schemas = self._parse_schema(schema) | ||
|
||
if tool.name == tool.__class__.__name__: | ||
tool.name = tool_name | ||
|
||
for activity_name, (description, activity_schema) in activity_schemas.items(): | ||
activity_handler = self._create_activity_handler(activity_name, description, activity_schema) | ||
|
||
setattr(tool, activity_name, MethodType(activity_handler, self)) | ||
|
||
def _get_schema(self) -> dict: | ||
url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/openapi") | ||
response = requests.get(url, headers=self.headers) | ||
|
||
response.raise_for_status() | ||
|
||
return response.json() | ||
|
||
def _parse_schema(self, schema: dict) -> tuple[str, dict[str, tuple[str, Schema]]]: | ||
"""Parses an openapi schema into a dictionary of activity names and their respective descriptions + schemas.""" | ||
activities = {} | ||
|
||
name = schema.get("info", {}).get("title") | ||
|
||
for path, path_info in schema.get("paths", {}).items(): | ||
if not path.startswith("/activities"): | ||
continue | ||
for method, method_info in path_info.items(): | ||
if "post" in method.lower(): | ||
activity_name = method_info["operationId"] | ||
description = method_info.get("description", "") | ||
|
||
activity_schema = self.__extract_schema_from_ref( | ||
schema, | ||
method_info.get("requestBody", {}) | ||
.get("content", {}) | ||
.get("application/json", {}) | ||
.get("schema", {}), | ||
) | ||
|
||
activities[activity_name] = (description, activity_schema) | ||
|
||
return name, activities | ||
|
||
def __extract_schema_from_ref(self, schema: dict, schema_ref: dict) -> Schema: | ||
"""Extracts a schema from a $ref if present, resolving it into native schema properties.""" | ||
if "$ref" in schema_ref: | ||
# Resolve the reference and retrieve the schema data | ||
ref_path = schema_ref["$ref"].split("/")[-1] | ||
schema_data = schema["components"]["schemas"].get(ref_path, {}) | ||
else: | ||
# Use the provided schema directly if no $ref is found | ||
schema_data = schema_ref | ||
|
||
# Convert the schema_data dictionary into a Schema with its properties | ||
properties = {} | ||
for prop, prop_info in schema_data.get("properties", {}).items(): | ||
prop_type = prop_info.get("type", "string") | ||
prop_description = prop_info.get("description", "") | ||
schema_prop = Literal(prop, description=prop_description) | ||
is_optional = prop not in schema_data.get("required", []) | ||
|
||
if is_optional: | ||
schema_prop = Optional(schema_prop) | ||
|
||
properties[schema_prop] = self._map_openapi_type_to_python(prop_type) | ||
|
||
return Schema(properties) | ||
|
||
def _map_openapi_type_to_python(self, openapi_type: str) -> type: | ||
"""Maps OpenAPI types to native Python types.""" | ||
type_mapping = {"string": str, "integer": int, "boolean": bool, "number": float, "array": list, "object": dict} | ||
|
||
return type_mapping.get(openapi_type, str) | ||
|
||
def _create_activity_handler(self, activity_name: str, description: str, activity_schema: Schema) -> Callable: | ||
"""Creates an activity handler method for the tool.""" | ||
|
||
@activity(config={"name": activity_name, "description": description, "schema": activity_schema}) | ||
def activity_handler(_: BaseTool, values: dict) -> Any: | ||
return self._run_activity(activity_name, values) | ||
|
||
return activity_handler | ||
|
||
def _run_activity(self, activity_name: str, params: dict) -> BaseArtifact: | ||
"""Runs an activity on the tool with the provided parameters.""" | ||
url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/activities/{activity_name}") | ||
|
||
response = requests.post(url, json=params, headers=self.headers) | ||
|
||
response.raise_for_status() | ||
|
||
try: | ||
return BaseArtifact.from_dict(response.json()) | ||
except ValueError: | ||
return TextArtifact(response.text) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
Oops, something went wrong.