From ded439b23a6bfca2b222cfd60959b398d6ce633f Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Tue, 12 Nov 2024 16:50:00 -0800 Subject: [PATCH] Add griptape cloud tool tool --- CHANGELOG.md | 4 +- .../griptape-cloud-tool-tool.md | 9 + .../src/griptape_cloud_tool_tool.py | 13 + griptape/tools/__init__.py | 2 + griptape/tools/base_griptape_cloud_tool.py | 9 +- .../tools/griptape_cloud_tool/__init__.py | 0 griptape/tools/griptape_cloud_tool/tool.py | 125 ++++++ mkdocs.yml | 1 + .../tools/test_griptape_cloud_tool_tool.py | 374 ++++++++++++++++++ 9 files changed, 532 insertions(+), 5 deletions(-) create mode 100644 docs/griptape-tools/official-tools/griptape-cloud-tool-tool.md create mode 100644 docs/griptape-tools/official-tools/src/griptape_cloud_tool_tool.py create mode 100644 griptape/tools/griptape_cloud_tool/__init__.py create mode 100644 griptape/tools/griptape_cloud_tool/tool.py create mode 100644 tests/unit/tools/test_griptape_cloud_tool_tool.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 128242ea7..e3742a8d2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,9 +16,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `AssistantTask` for running Assistants in Structures. - `GriptapeCloudAssistantDriver` for interacting with Griptape Cloud's Assistant API. - `OpenAiAssistantDriver` for interacting with OpenAI's Assistant API. -- New class of Drivers, Tool Drivers, for augmenting Tools with additional functionality. -- `GriptapeCloudToolDriver` for interacting with Griptape Cloud's Tool API. -- `BaseTool.tool_driver` for setting a Tool Driver on a Tool. The Tool will call `tool_driver.initialize_tool` after Tool initialization. +- `GriptapeCloudToolTool` for running Griptape Cloud hosted Tools. ### Changed diff --git a/docs/griptape-tools/official-tools/griptape-cloud-tool-tool.md b/docs/griptape-tools/official-tools/griptape-cloud-tool-tool.md new file mode 100644 index 000000000..d2dda4120 --- /dev/null +++ b/docs/griptape-tools/official-tools/griptape-cloud-tool-tool.md @@ -0,0 +1,9 @@ +# Griptape Cloud Tool Tool + +The [GriptapeCloudToolTool](../../reference/griptape/tools/griptape_cloud_tool/tool.md) integrates with Griptape Cloud's hosted Tools. + +**Note:** This tool requires a [Tool](https://cloud.griptape.ai/tools) hosted in Griptape Cloud and an [API Key](https://cloud.griptape.ai/configuration/api-keys) for access. + +```python +--8<-- "docs/griptape-tools/official-tools/src/griptape_cloud_tool_tool.py" +``` diff --git a/docs/griptape-tools/official-tools/src/griptape_cloud_tool_tool.py b/docs/griptape-tools/official-tools/src/griptape_cloud_tool_tool.py new file mode 100644 index 000000000..cabf7f1ab --- /dev/null +++ b/docs/griptape-tools/official-tools/src/griptape_cloud_tool_tool.py @@ -0,0 +1,13 @@ +import os + +from griptape.structures import Agent +from griptape.tools.griptape_cloud_tool.tool import GriptapeCloudToolTool + +agent = Agent( + tools=[ + GriptapeCloudToolTool( # Tool is configured as a random number generator + tool_id=os.environ["GT_CLOUD_TOOL_ID"], + ) + ] +) +agent.run("Generate a number between 1 and 10") diff --git a/griptape/tools/__init__.py b/griptape/tools/__init__.py index 6eb21e921..67a1712a1 100644 --- a/griptape/tools/__init__.py +++ b/griptape/tools/__init__.py @@ -14,6 +14,7 @@ from .variation_image_generation.tool import VariationImageGenerationTool from .inpainting_image_generation.tool import InpaintingImageGenerationTool from .outpainting_image_generation.tool import OutpaintingImageGenerationTool +from .griptape_cloud_tool.tool import GriptapeCloudToolTool from .structure_run.tool import StructureRunTool from .image_query.tool import ImageQueryTool from .rag.tool import RagTool @@ -40,6 +41,7 @@ "VariationImageGenerationTool", "InpaintingImageGenerationTool", "OutpaintingImageGenerationTool", + "GriptapeCloudToolTool", "StructureRunTool", "ImageQueryTool", "RagTool", diff --git a/griptape/tools/base_griptape_cloud_tool.py b/griptape/tools/base_griptape_cloud_tool.py index 7ee8f2dfc..d46641ee6 100644 --- a/griptape/tools/base_griptape_cloud_tool.py +++ b/griptape/tools/base_griptape_cloud_tool.py @@ -1,6 +1,8 @@ from __future__ import annotations +import os from abc import ABC +from typing import Optional from attrs import Factory, define, field @@ -17,8 +19,11 @@ class BaseGriptapeCloudTool(BaseTool, ABC): headers: Headers for the Griptape Cloud Knowledge Base API. """ - base_url: str = field(default="https://cloud.griptape.ai", kw_only=True) - api_key: str = field(kw_only=True) + base_url: str = field( + default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), + kw_only=True, + ) + api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")), kw_only=True) headers: dict = field( default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True, diff --git a/griptape/tools/griptape_cloud_tool/__init__.py b/griptape/tools/griptape_cloud_tool/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/tools/griptape_cloud_tool/tool.py b/griptape/tools/griptape_cloud_tool/tool.py new file mode 100644 index 000000000..ca54df3f4 --- /dev/null +++ b/griptape/tools/griptape_cloud_tool/tool.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from types import MethodType +from typing import Any, Callable +from urllib.parse import urljoin + +import requests +from attrs import define, field +from schema import Literal, Optional, Schema + +from griptape.artifacts import BaseArtifact, TextArtifact +from griptape.tools.base_griptape_cloud_tool import BaseGriptapeCloudTool +from griptape.utils.decorators import activity + + +@define() +class GriptapeCloudToolTool(BaseGriptapeCloudTool): + """Runs a Griptape Cloud hosted Tool. + + Attributes: + tool_id: The ID of the tool to run. + """ + + tool_id: str = field(kw_only=True) + + def __attrs_post_init__(self) -> None: + self._init_activities() + + def _init_activities(self) -> None: + schema = self._get_schema() + tool_name, activity_schemas = self._parse_schema(schema) + + if self.name == self.__class__.__name__: + self.name = tool_name + + for activity_name, (description, activity_schema) in activity_schemas.items(): + activity_handler = self._create_activity_handler(activity_name, description, activity_schema) + + setattr(self, activity_name, MethodType(activity_handler, self)) + + def _get_schema(self) -> dict: + response = requests.get(urljoin(self.base_url, f"/api/tools/{self.tool_id}/openapi"), headers=self.headers) + + response.raise_for_status() + + return response.json() + + def _parse_schema(self, schema: dict) -> tuple[str, dict[str, tuple[str, Schema]]]: + """Parses an openapi schema into a dictionary of activity names and their respective descriptions + schemas.""" + activities = {} + + name = schema.get("info", {}).get("title") + + for path, path_info in schema.get("paths", {}).items(): + if not path.startswith("/activities"): + continue + for method, method_info in path_info.items(): + if "post" in method.lower(): + activity_name = method_info["operationId"] + description = method_info.get("description", "") + + activity_schema = self.__extract_schema_from_ref( + schema, + method_info.get("requestBody", {}) + .get("content", {}) + .get("application/json", {}) + .get("schema", {}), + ) + + activities[activity_name] = (description, activity_schema) + + return name, activities + + def __extract_schema_from_ref(self, schema: dict, schema_ref: dict) -> Schema: + """Extracts a schema from a $ref if present, resolving it into native schema properties.""" + if "$ref" in schema_ref: + # Resolve the reference and retrieve the schema data + ref_path = schema_ref["$ref"].split("/")[-1] + schema_data = schema["components"]["schemas"].get(ref_path, {}) + else: + # Use the provided schema directly if no $ref is found + schema_data = schema_ref + + # Convert the schema_data dictionary into a Schema with its properties + properties = {} + for prop, prop_info in schema_data.get("properties", {}).items(): + prop_type = prop_info.get("type", "string") + prop_description = prop_info.get("description", "") + schema_prop = Literal(prop, description=prop_description) + is_optional = prop not in schema_data.get("required", []) + + if is_optional: + schema_prop = Optional(schema_prop) + + properties[schema_prop] = self._map_openapi_type_to_python(prop_type) + + return Schema(properties) + + def _map_openapi_type_to_python(self, openapi_type: str) -> type: + """Maps OpenAPI types to native Python types.""" + type_mapping = {"string": str, "integer": int, "boolean": bool, "number": float, "array": list, "object": dict} + + return type_mapping.get(openapi_type, str) + + def _create_activity_handler(self, activity_name: str, description: str, activity_schema: Schema) -> Callable: + """Creates an activity handler method for the tool.""" + + @activity(config={"name": activity_name, "description": description, "schema": activity_schema}) + def activity_handler(self: GriptapeCloudToolTool, values: dict) -> Any: + return self._run_activity(activity_name, values) + + return activity_handler + + def _run_activity(self, activity_name: str, params: dict) -> BaseArtifact: + """Runs an activity on the tool with the provided parameters.""" + url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/activities/{activity_name}") + + response = requests.post(url, json=params, headers=self.headers) + + response.raise_for_status() + + try: + return BaseArtifact.from_dict(response.json()) + except ValueError: + return TextArtifact(response.text) diff --git a/mkdocs.yml b/mkdocs.yml index 892069431..503d44445 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -161,6 +161,7 @@ nav: - Image Query: "griptape-tools/official-tools/image-query-tool.md" - Text To Speech: "griptape-tools/official-tools/text-to-speech-tool.md" - Audio Transcription: "griptape-tools/official-tools/audio-transcription-tool.md" + - Griptape Cloud Tool: "griptape-tools/official-tools/griptape-cloud-tool-tool.md" - Rag: "griptape-tools/official-tools/rag-tool.md" - Extraction: "griptape-tools/official-tools/extraction-tool.md" - Query: "griptape-tools/official-tools/query-tool.md" diff --git a/tests/unit/tools/test_griptape_cloud_tool_tool.py b/tests/unit/tools/test_griptape_cloud_tool_tool.py new file mode 100644 index 000000000..51e8c0767 --- /dev/null +++ b/tests/unit/tools/test_griptape_cloud_tool_tool.py @@ -0,0 +1,374 @@ +import json + +import pytest +import schema + +from griptape.artifacts.text_artifact import TextArtifact +from griptape.tools import GriptapeCloudToolTool + +MOCK_SCHEMA = { + "openapi": "3.1.0", + "info": {"title": "DataProcessor", "version": "0.2.0"}, + "servers": [{"url": "https://cloud.griptape.ai/api/tools/1"}], + "paths": { + "/activities/processString": { + "post": { + "tags": ["Activities"], + "summary": "Process String", + "description": "Processes a string input", + "operationId": "processString", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/StringInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "String processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processNumber": { + "post": { + "tags": ["Activities"], + "summary": "Process Number", + "description": "Processes a number input", + "operationId": "processNumber", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/NumberInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "Number processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processBoolean": { + "post": { + "tags": ["Activities"], + "summary": "Process Boolean", + "description": "Processes a boolean input", + "operationId": "processBoolean", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BooleanInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "Boolean processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processArray": { + "post": { + "tags": ["Activities"], + "summary": "Process Array", + "description": "Processes an array input", + "operationId": "processArray", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/ArrayInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "Array processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processObject": { + "post": { + "tags": ["Activities"], + "summary": "Process Object", + "description": "Processes an object input", + "operationId": "processObject", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/ObjectInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "Object processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processRequired": { + "post": { + "tags": ["Activities"], + "summary": "Process Required", + "description": "Processes a required input", + "operationId": "processRequired", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/RequiredInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "Required processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processNoRef": { + "post": { + "tags": ["Activities"], + "summary": "Process No Ref", + "description": "Processes a no ref input", + "operationId": "processNoRef", + "requestBody": { + "content": { + "application/json": { + "schema": { + "properties": { + "text": {"type": "string", "title": "Text", "description": "The string to process"}, + }, + "type": "object", + "title": "StringInput", + } + } + }, + "required": True, + }, + "responses": { + "200": { + "description": "Required processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processWrongMethod": {"get": {}}, + "/openapi": { + "get": { + "tags": ["OpenAPI"], + "summary": "OpenAPI Specification", + "description": "Get the OpenAPI specification for this tool", + "operationId": "openapi", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {"type": "object", "title": "Response Openapi"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + }, + "components": { + "schemas": { + "BaseArtifact": { + "properties": { + "value": {"title": "Value", "description": "The return value of the activity"}, + "type": {"type": "string", "title": "Type", "description": "The type of the return value"}, + }, + "type": "object", + "required": ["value", "type"], + "title": "BaseArtifact", + }, + "ErrorResponse": { + "properties": { + "error": {"type": "string", "title": "Error", "description": "The return value of the activity"} + }, + "type": "object", + "required": ["error"], + "title": "ErrorResponse", + }, + "StringInput": { + "properties": { + "text": {"type": "string", "title": "Text", "description": "The string to process"}, + }, + "type": "object", + "title": "StringInput", + }, + "NumberInput": { + "properties": { + "number": {"type": "number", "title": "Number", "description": "The number to process"}, + }, + "type": "object", + "title": "NumberInput", + }, + "BooleanInput": { + "properties": { + "flag": {"type": "boolean", "title": "Flag", "description": "The boolean to process"}, + }, + "type": "object", + "title": "BooleanInput", + }, + "ArrayInput": { + "properties": { + "items": { + "type": "array", + "title": "Items", + "description": "An array of numbers", + "items": {"type": "number"}, + } + }, + "type": "object", + "title": "ArrayInput", + }, + "ObjectInput": { + "properties": { + "data": { + "type": "object", + "title": "Data", + "description": "An object containing key-value pairs", + "additionalProperties": {"type": "string"}, + } + }, + "type": "object", + "title": "ObjectInput", + }, + "RequiredInput": { + "properties": { + "required": { + "type": "string", + "title": "Required", + "description": "A required input field", + } + }, + "type": "object", + "required": ["required"], + "title": "RequiredInput", + }, + }, + "securitySchemes": {"bearerAuth": {"type": "http", "scheme": "bearer"}}, + }, +} + + +class TestGriptapeCloudToolTool: + @pytest.fixture() + def mock_schema(self): + return MOCK_SCHEMA + + @pytest.fixture(autouse=True) + def _mock_requests_get(self, mocker, mock_schema): + mock_get = mocker.patch("requests.get") + mock_get.return_value.status_code = 200 + mock_get.return_value.json.return_value = mock_schema + + @pytest.fixture(autouse=True) + def _mock_requests_post(self, mocker): + mock_get = mocker.patch("requests.post") + mock_get.return_value.status_code = 200 + mock_get.return_value.json.return_value = TextArtifact("foo").to_dict() + + def test_init(self): + tool = GriptapeCloudToolTool(tool_id="tool_id") + + # Define expected activity details for each method + expected_activities = { + "processString": { + "name": "processString", + "description": "Processes a string input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("text", description="The string to process")): str} + ), + }, + "processNumber": { + "name": "processNumber", + "description": "Processes a number input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("number", description="The number to process")): float} + ), + }, + "processBoolean": { + "name": "processBoolean", + "description": "Processes a boolean input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("flag", description="The boolean to process")): bool} + ), + }, + "processArray": { + "name": "processArray", + "description": "Processes an array input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("items", description="An array of numbers")): list} + ), + }, + "processObject": { + "name": "processObject", + "description": "Processes an object input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("data", description="An object containing key-value pairs")): dict} + ), + }, + "processRequired": { + "name": "processRequired", + "description": "Processes a required input", + "schema": schema.Schema({schema.Literal("required", description="A required input field"): str}), + }, + "processNoRef": { + "name": "processNoRef", + "description": "Processes a no ref input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("text", description="The string to process")): str} + ), + }, + } + + for activity_name, details in expected_activities.items(): + assert hasattr(tool, activity_name), f"Method {activity_name} does not exist in the tool." + activity = getattr(tool, activity_name) + + assert getattr(activity, "name") == details["name"] + assert getattr(activity, "config")["name"] == details["name"] + assert getattr(activity, "config")["description"] == details["description"] + assert getattr(activity, "config")["schema"].json_schema("Schema") == details["schema"].json_schema( + "Schema" + ) + + assert getattr(activity, "is_activity") is True + + def test_multiple_init(self, mock_schema): + tool_1 = GriptapeCloudToolTool(tool_id="tool_id_1") + mock_schema["paths"]["/activities/processString"]["post"]["description"] = "new description" + tool_2 = GriptapeCloudToolTool(tool_id="tool_id_2") + + assert getattr(tool_1, "processString") != getattr(tool_2, "processString") + assert getattr(tool_1, "processString").config["description"] == "Processes a string input" + assert getattr(tool_2, "processString").config["description"] == "new description" + + def test_run_activity(self): + tool = GriptapeCloudToolTool(tool_id="tool_id") + response = tool.processString({"text": "foo"}) # pyright: ignore[reportAttributeAccessIssue] + + assert response.value == "foo" + + def test_name(self): + tool = GriptapeCloudToolTool(tool_id="tool_id") + + assert tool.name == "DataProcessor" + + tool = GriptapeCloudToolTool(tool_id="tool_id", name="CustomName") + + assert tool.name == "CustomName" + + def test_bad_artifact(self, mocker): + mock_get = mocker.patch("requests.post") + mock_get.return_value.status_code = 200 + return_value = {"type": "FooBarArtifact", "value": "foo"} + mock_get.return_value.json.return_value = return_value + mock_get.return_value.text = json.dumps(return_value) + tool = GriptapeCloudToolTool(tool_id="tool_id") + + result = tool.processString({"text": 1}) + assert isinstance(result, TextArtifact) + assert result.value == json.dumps(return_value)