diff --git a/.github/workflows/docs-integration-tests.yml b/.github/workflows/docs-integration-tests.yml index e765f2582..06b0f0066 100644 --- a/.github/workflows/docs-integration-tests.yml +++ b/.github/workflows/docs-integration-tests.yml @@ -68,6 +68,7 @@ jobs: GT_CLOUD_STRUCTURE_ID: ${{ vars.INTEG_GT_CLOUD_STRUCTURE_ID }} GT_CLOUD_STRUCTURE_RUN_ID: ${{ vars.INTEG_GT_CLOUD_STRUCTURE_RUN_ID }} GT_CLOUD_THREAD_ID: ${{ vars.INTEG_GT_CLOUD_THREAD_ID }} + GT_CLOUD_TOOL_ID: ${{ vars.INTEG_GT_CLOUD_TOOL_ID }} HUGGINGFACE_HUB_ACCESS_TOKEN: ${{ secrets.INTEG_HUGGINGFACE_HUB_ACCESS_TOKEN }} LEONARDO_API_KEY: ${{ secrets.INTEG_LEONARDO_API_KEY }} LEONARDO_MODEL_ID: ${{ secrets.INTEG_LEONARDO_MODEL_ID }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 64816349a..6aebea667 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `AssistantTask` for running Assistants in Structures. - `GriptapeCloudAssistantDriver` for interacting with Griptape Cloud's Assistant API. - `OpenAiAssistantDriver` for interacting with OpenAI's Assistant API. +- New class of Drivers, Tool Drivers, for augmenting Tools with additional functionality. +- `GriptapeCloudToolDriver` for interacting with Griptape Cloud's Tool API. +- `BaseTool.tool_driver` for setting a Tool Driver on a Tool. The Tool will call `tool_driver.initialize_tool` after Tool initialization. ### Changed diff --git a/README.md b/README.md index 7edc661c2..d9ec93366 100644 --- a/README.md +++ b/README.md @@ -52,6 +52,7 @@ Drivers facilitate interactions with external resources and services: - 📈 **Observability Drivers** send trace and event data to observability platforms. - 📜 **Ruleset Drivers** load and apply rulesets from external sources. - 🗂️ **File Manager Drivers** handle file operations on local and remote storage. +- 🔨 **Tool Drivers** augment Tools with additional functionality. ### 🚂 Engines @@ -217,7 +218,7 @@ Run `make check` to run all code checks locally. ### Griptape Extensions -Griptape's extensibility allows anyone to develop and distribute functionality independently. +Griptape's extensibility allows anyone to develop and distribute functionality independently. All new integrations, including Tools, Drivers, Tasks, etc., should initially be developed as extensions and then can be upstreamed into Griptape core if discussed and approved. The [Griptape Extension Template](https://github.com/griptape-ai/griptape-extension-template) provides the recommended structure, step-by-step instructions, basic automation, and usage examples for new integrations. diff --git a/docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_1.py b/docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_1.py new file mode 100644 index 000000000..e6bb4952e --- /dev/null +++ b/docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_1.py @@ -0,0 +1,22 @@ +import os + +from griptape.drivers import GriptapeCloudToolDriver +from griptape.structures import Agent +from griptape.tools import BaseTool + + +class RandomNumberGeneratorTool(BaseTool): ... + + +tool = RandomNumberGeneratorTool() +driver = GriptapeCloudToolDriver( + api_key=os.environ["GT_CLOUD_API_KEY"], + tool_id=os.environ["GT_CLOUD_TOOL_ID"], +) +driver.initialize_tool(tool) + +agent = Agent( + tools=[tool], +) + +agent.run("Generate a random number to 5 decimal places.") diff --git a/docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_2.py b/docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_2.py new file mode 100644 index 000000000..6ae99a015 --- /dev/null +++ b/docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_2.py @@ -0,0 +1,40 @@ +import os +import random + +import schema + +from griptape.artifacts.text_artifact import TextArtifact +from griptape.drivers import GriptapeCloudToolDriver +from griptape.structures import Agent +from griptape.tools.base_tool import BaseTool +from griptape.utils.decorators import activity + + +class ExtendedRandomNumberGeneratorTool(BaseTool): + @activity( + config={ + "description": "Generates a random number in a range.", + "schema": schema.Schema( + { + "min": schema.Or(float, int), + "max": schema.Or(float, int), + } + ), + } + ) + def rand_range(self, values: dict) -> TextArtifact: + return TextArtifact(random.uniform(values["min"], values["max"])) + + +agent = Agent( + tools=[ + ExtendedRandomNumberGeneratorTool( + tool_driver=GriptapeCloudToolDriver( + api_key=os.environ["GT_CLOUD_API_KEY"], + tool_id=os.environ["GT_CLOUD_TOOL_ID"], + ) + ) + ], +) + +agent.run("Generate a random number between 1 and 100.") diff --git a/docs/griptape-framework/drivers/src/tool_driver_1.py b/docs/griptape-framework/drivers/src/tool_driver_1.py new file mode 100644 index 000000000..c541f01c5 --- /dev/null +++ b/docs/griptape-framework/drivers/src/tool_driver_1.py @@ -0,0 +1,22 @@ +import os + +from griptape.drivers import GriptapeCloudToolDriver +from griptape.structures import Agent +from griptape.tools import BaseTool + + +class RandomNumberGeneratorTool(BaseTool): ... + + +agent = Agent( + tools=[ + RandomNumberGeneratorTool( + tool_driver=GriptapeCloudToolDriver( + api_key=os.environ["GT_CLOUD_API_KEY"], + tool_id=os.environ["GT_CLOUD_TOOL_ID"], + ) + ) + ], +) + +agent.run("Generate a random number to 5 decimal places.") diff --git a/docs/griptape-framework/drivers/tool-drivers.md b/docs/griptape-framework/drivers/tool-drivers.md new file mode 100644 index 000000000..51fa1f57a --- /dev/null +++ b/docs/griptape-framework/drivers/tool-drivers.md @@ -0,0 +1,32 @@ +--- +search: + boost: 2 +--- + +## Overview + +Tool Drivers can be used to augment [Tools](../../griptape-tools/index.md) with additional functionality. + +You can use Tool Drivers with Tools: + +```python +--8<-- "docs/griptape-framework/drivers/src/tool_driver_1.py" +``` + +In this example, the `RandomNumberGeneratorTool`'s schema is loaded from Griptape Cloud. Additionally, when the `RandomNumberGeneratorTool` is executed, it is run in the Cloud rather than locally. + +## Tool Drivers + +### Griptape Cloud + +The [GriptapeCloudToolDriver](../../reference/griptape/drivers/tool/griptape_cloud_tool_driver.md) uses Griptape Cloud's hosted Tools service to modify the Tool with the hosted Tool's schema and execution logic. + +```python +--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_1.py" +``` + +The Driver will only overwrite activities present in the hosted Tool, meaning you can mix and match local and hosted activities. + +```python +--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_tool_driver_2.py" +``` diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index 36c153f7d..da7f6a93b 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -133,6 +133,8 @@ from .assistant.base_assistant_driver import BaseAssistantDriver from .assistant.griptape_cloud_assistant_driver import GriptapeCloudAssistantDriver from .assistant.openai_assistant_driver import OpenAiAssistantDriver +from .tool.base_tool_driver import BaseToolDriver +from .tool.griptape_cloud_tool_driver import GriptapeCloudToolDriver __all__ = [ "BasePromptDriver", @@ -240,4 +242,6 @@ "BaseAssistantDriver", "GriptapeCloudAssistantDriver", "OpenAiAssistantDriver", + "BaseToolDriver", + "GriptapeCloudToolDriver", ] diff --git a/griptape/drivers/tool/__init__.py b/griptape/drivers/tool/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/griptape/drivers/tool/base_tool_driver.py b/griptape/drivers/tool/base_tool_driver.py new file mode 100644 index 000000000..3e14af610 --- /dev/null +++ b/griptape/drivers/tool/base_tool_driver.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING + +from attrs import define + +if TYPE_CHECKING: + from griptape.tools.base_tool import BaseTool + + +@define +class BaseToolDriver(ABC): + @abstractmethod + def initialize_tool(self, tool: BaseTool) -> None: + """Performs some initialization operations on the Tool. This method is intended to mutate the Tool object. + + Args: + tool: Tool to initialize. + + """ + + ... diff --git a/griptape/drivers/tool/griptape_cloud_tool_driver.py b/griptape/drivers/tool/griptape_cloud_tool_driver.py new file mode 100644 index 000000000..7e78dbab4 --- /dev/null +++ b/griptape/drivers/tool/griptape_cloud_tool_driver.py @@ -0,0 +1,136 @@ +from __future__ import annotations + +import os +from types import MethodType +from typing import TYPE_CHECKING, Any, Callable +from urllib.parse import urljoin + +import requests +from attrs import Factory, define, field +from schema import Literal, Optional, Schema + +from griptape.artifacts import BaseArtifact, TextArtifact +from griptape.drivers.tool.base_tool_driver import BaseToolDriver +from griptape.utils.decorators import activity + +if TYPE_CHECKING: + from griptape.tools.base_tool import BaseTool + + +@define +class GriptapeCloudToolDriver(BaseToolDriver): + """Driver for interacting with tools hosted on the Griptape Cloud. + + Attributes: + base_url: Base URL of the Griptape Cloud. + api_key: API key for the Griptape Cloud. + tool_id: ID of the tool to interact with. + headers: Headers to use for requests. + """ + + base_url: str = field(default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai"))) + api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"])) + tool_id: str = field(kw_only=True) + headers: dict = field( + default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), + init=False, + ) + + def initialize_tool(self, tool: BaseTool) -> None: + schema = self._get_schema() + tool_name, activity_schemas = self._parse_schema(schema) + + if tool.name == tool.__class__.__name__: + tool.name = tool_name + + for activity_name, (description, activity_schema) in activity_schemas.items(): + activity_handler = self._create_activity_handler(activity_name, description, activity_schema) + + setattr(tool, activity_name, MethodType(activity_handler, self)) + + def _get_schema(self) -> dict: + url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/openapi") + response = requests.get(url, headers=self.headers) + + response.raise_for_status() + + return response.json() + + def _parse_schema(self, schema: dict) -> tuple[str, dict[str, tuple[str, Schema]]]: + """Parses an openapi schema into a dictionary of activity names and their respective descriptions + schemas.""" + activities = {} + + name = schema.get("info", {}).get("title") + + for path, path_info in schema.get("paths", {}).items(): + if not path.startswith("/activities"): + continue + for method, method_info in path_info.items(): + if "post" in method.lower(): + activity_name = method_info["operationId"] + description = method_info.get("description", "") + + activity_schema = self.__extract_schema_from_ref( + schema, + method_info.get("requestBody", {}) + .get("content", {}) + .get("application/json", {}) + .get("schema", {}), + ) + + activities[activity_name] = (description, activity_schema) + + return name, activities + + def __extract_schema_from_ref(self, schema: dict, schema_ref: dict) -> Schema: + """Extracts a schema from a $ref if present, resolving it into native schema properties.""" + if "$ref" in schema_ref: + # Resolve the reference and retrieve the schema data + ref_path = schema_ref["$ref"].split("/")[-1] + schema_data = schema["components"]["schemas"].get(ref_path, {}) + else: + # Use the provided schema directly if no $ref is found + schema_data = schema_ref + + # Convert the schema_data dictionary into a Schema with its properties + properties = {} + for prop, prop_info in schema_data.get("properties", {}).items(): + prop_type = prop_info.get("type", "string") + prop_description = prop_info.get("description", "") + schema_prop = Literal(prop, description=prop_description) + is_optional = prop not in schema_data.get("required", []) + + if is_optional: + schema_prop = Optional(schema_prop) + + properties[schema_prop] = self._map_openapi_type_to_python(prop_type) + + return Schema(properties) + + def _map_openapi_type_to_python(self, openapi_type: str) -> type: + """Maps OpenAPI types to native Python types.""" + type_mapping = {"string": str, "integer": int, "boolean": bool, "number": float, "array": list, "object": dict} + + return type_mapping.get(openapi_type, str) + + def _create_activity_handler(self, activity_name: str, description: str, activity_schema: Schema) -> Callable: + """Creates an activity handler method for the tool.""" + + @activity(config={"name": activity_name, "description": description, "schema": activity_schema}) + def activity_handler(_: BaseTool, values: dict) -> Any: + return self._run_activity(activity_name, values) + + return activity_handler + + def _run_activity(self, activity_name: str, params: dict) -> BaseArtifact: + """Runs an activity on the tool with the provided parameters.""" + url = urljoin(self.base_url, f"/api/tools/{self.tool_id}/activities/{activity_name}") + + response = requests.post(url, json=params, headers=self.headers) + + response.raise_for_status() + + try: + return BaseArtifact.from_dict(response.json()) + except ValueError: + return TextArtifact(response.text) diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 4a049752b..ab312c436 100644 --- a/griptape/schemas/base_schema.py +++ b/griptape/schemas/base_schema.py @@ -168,6 +168,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: BasePromptDriver, BaseRulesetDriver, BaseTextToSpeechDriver, + BaseToolDriver, BaseVectorStoreDriver, ) from griptape.events import EventListener @@ -192,6 +193,7 @@ def _resolve_types(cls, attrs_cls: type) -> None: "BaseConversationMemoryDriver": BaseConversationMemoryDriver, "BaseRulesetDriver": BaseRulesetDriver, "BaseImageGenerationDriver": BaseImageGenerationDriver, + "BaseToolDriver": BaseToolDriver, "BaseArtifact": BaseArtifact, "PromptStack": PromptStack, "EventListener": EventListener, diff --git a/griptape/tools/base_tool.py b/griptape/tools/base_tool.py index 560d79250..02bdfa4e4 100644 --- a/griptape/tools/base_tool.py +++ b/griptape/tools/base_tool.py @@ -23,6 +23,7 @@ if TYPE_CHECKING: from griptape.common import ToolAction + from griptape.drivers import BaseToolDriver from griptape.memory import TaskMemory from griptape.tasks import ActionsSubtask @@ -56,6 +57,7 @@ class BaseTool(ActivityMixin, SerializableMixin, RunnableMixin["BaseTool"], ABC) dependencies_install_directory: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) verbose: bool = field(default=False, kw_only=True, metadata={"serializable": True}) off_prompt: bool = field(default=False, kw_only=True, metadata={"serializable": True}) + tool_driver: BaseToolDriver = field(default=None, kw_only=True) def __attrs_post_init__(self) -> None: if ( @@ -65,6 +67,9 @@ def __attrs_post_init__(self) -> None: ): self.install_dependencies(os.environ.copy()) + if self.tool_driver is not None: + self.tool_driver.initialize_tool(self) + @output_memory.validator # pyright: ignore[reportAttributeAccessIssue] def validate_output_memory(self, _: Attribute, output_memory: dict[str, Optional[list[TaskMemory]]]) -> None: if output_memory: diff --git a/griptape/utils/decorators.py b/griptape/utils/decorators.py index 3eef6d8d0..3a3bbfd28 100644 --- a/griptape/utils/decorators.py +++ b/griptape/utils/decorators.py @@ -2,13 +2,14 @@ import functools import inspect -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, OrderedDict, cast import schema from schema import Schema CONFIG_SCHEMA = Schema( { + schema.Optional("name"): str, "description": str, schema.Optional("schema"): lambda data: isinstance(data, (Schema, Callable)), } @@ -28,7 +29,7 @@ def decorator(func: Callable) -> Any: def wrapper(self: Any, params: dict) -> Any: return func(self, **_build_kwargs(func, params)) - setattr(wrapper, "name", func.__name__) + setattr(wrapper, "name", validated_config.get("name", func.__name__)) setattr(wrapper, "config", validated_config) setattr(wrapper, "is_activity", True) @@ -58,8 +59,8 @@ def lazy_attr(self: Any, value: Any) -> None: def _build_kwargs(func: Callable, params: dict) -> dict: - func_params = inspect.signature(func).parameters.copy() - func_params.pop("self") + func_params = cast(OrderedDict, inspect.signature(func).parameters.copy()) + func_params.popitem(last=False) kwarg_var = None for param in func_params.values(): diff --git a/mkdocs.yml b/mkdocs.yml index a8cc094e3..b8cbe6822 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -131,6 +131,7 @@ nav: - Observability Drivers: "griptape-framework/drivers/observability-drivers.md" - Ruleset Drivers: "griptape-framework/drivers/ruleset-drivers.md" - File Manager Drivers: "griptape-framework/drivers/file-manager-drivers.md" + - Tool Drivers: "griptape-framework/drivers/tool-drivers.md" - Data: - Overview: "griptape-framework/data/index.md" - Artifacts: "griptape-framework/data/artifacts.md" diff --git a/tests/unit/drivers/tool/__init__.py b/tests/unit/drivers/tool/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/drivers/tool/test_griptape_cloud_tool_driver.py b/tests/unit/drivers/tool/test_griptape_cloud_tool_driver.py new file mode 100644 index 000000000..c7d410f30 --- /dev/null +++ b/tests/unit/drivers/tool/test_griptape_cloud_tool_driver.py @@ -0,0 +1,375 @@ +import json + +import pytest +import schema + +from griptape.artifacts.text_artifact import TextArtifact +from griptape.drivers import GriptapeCloudToolDriver +from tests.unit.drivers.prompt.test_base_prompt_driver import MockTool + +MOCK_SCHEMA = { + "openapi": "3.1.0", + "info": {"title": "DataProcessor", "version": "0.2.0"}, + "servers": [{"url": "https://cloud.griptape.ai/api/tools/1"}], + "paths": { + "/activities/processString": { + "post": { + "tags": ["Activities"], + "summary": "Process String", + "description": "Processes a string input", + "operationId": "processString", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/StringInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "String processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processNumber": { + "post": { + "tags": ["Activities"], + "summary": "Process Number", + "description": "Processes a number input", + "operationId": "processNumber", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/NumberInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "Number processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processBoolean": { + "post": { + "tags": ["Activities"], + "summary": "Process Boolean", + "description": "Processes a boolean input", + "operationId": "processBoolean", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BooleanInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "Boolean processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processArray": { + "post": { + "tags": ["Activities"], + "summary": "Process Array", + "description": "Processes an array input", + "operationId": "processArray", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/ArrayInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "Array processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processObject": { + "post": { + "tags": ["Activities"], + "summary": "Process Object", + "description": "Processes an object input", + "operationId": "processObject", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/ObjectInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "Object processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processRequired": { + "post": { + "tags": ["Activities"], + "summary": "Process Required", + "description": "Processes a required input", + "operationId": "processRequired", + "requestBody": { + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/RequiredInput"}}}, + "required": True, + }, + "responses": { + "200": { + "description": "Required processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processNoRef": { + "post": { + "tags": ["Activities"], + "summary": "Process No Ref", + "description": "Processes a no ref input", + "operationId": "processNoRef", + "requestBody": { + "content": { + "application/json": { + "schema": { + "properties": { + "text": {"type": "string", "title": "Text", "description": "The string to process"}, + }, + "type": "object", + "title": "StringInput", + } + } + }, + "required": True, + }, + "responses": { + "200": { + "description": "Required processed", + "content": {"application/json": {"schema": {"$ref": "#/components/schemas/BaseArtifact"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + "/activities/processWrongMethod": {"get": {}}, + "/openapi": { + "get": { + "tags": ["OpenAPI"], + "summary": "OpenAPI Specification", + "description": "Get the OpenAPI specification for this tool", + "operationId": "openapi", + "responses": { + "200": { + "description": "Successful Response", + "content": {"application/json": {"schema": {"type": "object", "title": "Response Openapi"}}}, + } + }, + "security": [{"bearerAuth": []}], + } + }, + }, + "components": { + "schemas": { + "BaseArtifact": { + "properties": { + "value": {"title": "Value", "description": "The return value of the activity"}, + "type": {"type": "string", "title": "Type", "description": "The type of the return value"}, + }, + "type": "object", + "required": ["value", "type"], + "title": "BaseArtifact", + }, + "ErrorResponse": { + "properties": { + "error": {"type": "string", "title": "Error", "description": "The return value of the activity"} + }, + "type": "object", + "required": ["error"], + "title": "ErrorResponse", + }, + "StringInput": { + "properties": { + "text": {"type": "string", "title": "Text", "description": "The string to process"}, + }, + "type": "object", + "title": "StringInput", + }, + "NumberInput": { + "properties": { + "number": {"type": "number", "title": "Number", "description": "The number to process"}, + }, + "type": "object", + "title": "NumberInput", + }, + "BooleanInput": { + "properties": { + "flag": {"type": "boolean", "title": "Flag", "description": "The boolean to process"}, + }, + "type": "object", + "title": "BooleanInput", + }, + "ArrayInput": { + "properties": { + "items": { + "type": "array", + "title": "Items", + "description": "An array of numbers", + "items": {"type": "number"}, + } + }, + "type": "object", + "title": "ArrayInput", + }, + "ObjectInput": { + "properties": { + "data": { + "type": "object", + "title": "Data", + "description": "An object containing key-value pairs", + "additionalProperties": {"type": "string"}, + } + }, + "type": "object", + "title": "ObjectInput", + }, + "RequiredInput": { + "properties": { + "required": { + "type": "string", + "title": "Required", + "description": "A required input field", + } + }, + "type": "object", + "required": ["required"], + "title": "RequiredInput", + }, + }, + "securitySchemes": {"bearerAuth": {"type": "http", "scheme": "bearer"}}, + }, +} + + +class TestGriptapeCloudToolTool: + @pytest.fixture() + def mock_schema(self): + return MOCK_SCHEMA + + @pytest.fixture(autouse=True) + def _mock_requests_get(self, mocker, mock_schema): + mock_get = mocker.patch("requests.get") + mock_get.return_value.status_code = 200 + mock_get.return_value.json.return_value = mock_schema + + @pytest.fixture(autouse=True) + def _mock_requests_post(self, mocker): + mock_get = mocker.patch("requests.post") + mock_get.return_value.status_code = 200 + mock_get.return_value.json.return_value = TextArtifact("foo").to_dict() + + def test_init(self): + tool = MockTool(tool_driver=GriptapeCloudToolDriver(tool_id="tool_id", api_key="foo")) + + # Define expected activity details for each method + expected_activities = { + "processString": { + "name": "processString", + "description": "Processes a string input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("text", description="The string to process")): str} + ), + }, + "processNumber": { + "name": "processNumber", + "description": "Processes a number input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("number", description="The number to process")): float} + ), + }, + "processBoolean": { + "name": "processBoolean", + "description": "Processes a boolean input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("flag", description="The boolean to process")): bool} + ), + }, + "processArray": { + "name": "processArray", + "description": "Processes an array input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("items", description="An array of numbers")): list} + ), + }, + "processObject": { + "name": "processObject", + "description": "Processes an object input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("data", description="An object containing key-value pairs")): dict} + ), + }, + "processRequired": { + "name": "processRequired", + "description": "Processes a required input", + "schema": schema.Schema({schema.Literal("required", description="A required input field"): str}), + }, + "processNoRef": { + "name": "processNoRef", + "description": "Processes a no ref input", + "schema": schema.Schema( + {schema.Optional(schema.Literal("text", description="The string to process")): str} + ), + }, + } + + for activity_name, details in expected_activities.items(): + assert hasattr(tool, activity_name), f"Method {activity_name} does not exist in the tool." + activity = getattr(tool, activity_name) + + assert getattr(activity, "name") == details["name"] + assert getattr(activity, "config")["name"] == details["name"] + assert getattr(activity, "config")["description"] == details["description"] + assert getattr(activity, "config")["schema"].json_schema("Schema") == details["schema"].json_schema( + "Schema" + ) + + assert getattr(activity, "is_activity") is True + + def test_multiple_init(self, mock_schema): + tool_1 = MockTool(tool_driver=GriptapeCloudToolDriver(tool_id="tool_id_1", api_key="foo")) + mock_schema["paths"]["/activities/processString"]["post"]["description"] = "new description" + tool_2 = MockTool(tool_driver=GriptapeCloudToolDriver(tool_id="tool_id_2", api_key="foo")) + + assert getattr(tool_1, "processString") != getattr(tool_2, "processString") + assert getattr(tool_1, "processString").config["description"] == "Processes a string input" + assert getattr(tool_2, "processString").config["description"] == "new description" + + def test_run_activity(self): + tool = MockTool(tool_driver=GriptapeCloudToolDriver(tool_id="tool_id", api_key="foo")) + response = tool.processString({"text": "foo"}) # pyright: ignore[reportAttributeAccessIssue] + + assert response.value == "foo" + + def test_name(self): + tool = MockTool(tool_driver=GriptapeCloudToolDriver(tool_id="tool_id", api_key="foo")) + + assert tool.name == "DataProcessor" + + tool = MockTool(tool_driver=GriptapeCloudToolDriver(tool_id="tool_id", api_key="foo"), name="CustomName") + + assert tool.name == "CustomName" + + def test_bad_artifact(self, mocker): + mock_get = mocker.patch("requests.post") + mock_get.return_value.status_code = 200 + return_value = {"type": "FooBarArtifact", "value": "foo"} + mock_get.return_value.json.return_value = return_value + mock_get.return_value.text = json.dumps(return_value) + tool = MockTool(tool_driver=GriptapeCloudToolDriver(tool_id="tool_id", api_key="foo")) + + result = tool.processString({"text": 1}) # pyright: ignore[reportAttributeAccessIssue] + assert isinstance(result, TextArtifact) + assert result.value == json.dumps(return_value)