From cb75a6e1631521b063d1f0f010d65d6b4c39bf06 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 30 Oct 2024 14:40:14 -0700 Subject: [PATCH] Add support for assistants to structure run driver --- MIGRATION.md | 4 - .../drivers/src/structure_run_drivers_2.py | 4 +- .../src/structure_run_tool_1.py | 4 +- .../griptape_cloud_structure_run_driver.py | 76 +++++++++---------- griptape/utils/stream.py | 12 +-- ...est_griptape_cloud_structure_run_driver.py | 2 +- 6 files changed, 41 insertions(+), 61 deletions(-) diff --git a/MIGRATION.md b/MIGRATION.md index d96d5ca0e..5e2d51f9d 100644 --- a/MIGRATION.md +++ b/MIGRATION.md @@ -39,7 +39,6 @@ Defaults.drivers_config = AnthropicDriversConfig( Many callables have been renamed for consistency. Update your code to use the new names using the [CHANGELOG.md](https://github.com/griptape-ai/griptape/pull/1275/files#diff-06572a96a58dc510037d5efa622f9bec8519bc1beab13c9f251e97e657a9d4ed) as the source of truth. - ### Removed `CompletionChunkEvent` `CompletionChunkEvent` has been removed. There is now `BaseChunkEvent` with children `TextChunkEvent` and `ActionChunkEvent`. `BaseChunkEvent` can replace `completion_chunk_event.token` by doing `str(base_chunk_event)`. @@ -146,7 +145,6 @@ event_listener_driver.flush_events() The `observable` decorator has been moved to `griptape.common.decorators`. Update your imports accordingly. - #### Before ```python @@ -183,7 +181,6 @@ driver = HuggingFacePipelinePromptDriver( `execute` has been renamed to `run` in several places. Update your code accordingly. - #### Before ```python @@ -298,7 +295,6 @@ pip install griptape[drivers-prompt-huggingface-hub] pip install torch ``` - ### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types `CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a `list[TextArtifact]` instead of `list[CsvRowArtifact]`. diff --git a/docs/griptape-framework/drivers/src/structure_run_drivers_2.py b/docs/griptape-framework/drivers/src/structure_run_drivers_2.py index bec40c6ee..dc51d355a 100644 --- a/docs/griptape-framework/drivers/src/structure_run_drivers_2.py +++ b/docs/griptape-framework/drivers/src/structure_run_drivers_2.py @@ -7,7 +7,7 @@ base_url = os.environ["GRIPTAPE_CLOUD_BASE_URL"] api_key = os.environ["GRIPTAPE_CLOUD_API_KEY"] -structure_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"] +resource_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"] pipeline = Pipeline( @@ -32,7 +32,7 @@ driver=GriptapeCloudStructureRunDriver( base_url=base_url, api_key=api_key, - structure_id=structure_id, + resource_id=resource_id, ), ), ] diff --git a/docs/griptape-tools/official-tools/src/structure_run_tool_1.py b/docs/griptape-tools/official-tools/src/structure_run_tool_1.py index 575092ce6..0be48a602 100644 --- a/docs/griptape-tools/official-tools/src/structure_run_tool_1.py +++ b/docs/griptape-tools/official-tools/src/structure_run_tool_1.py @@ -6,14 +6,14 @@ base_url = os.environ["GRIPTAPE_CLOUD_BASE_URL"] api_key = os.environ["GRIPTAPE_CLOUD_API_KEY"] -structure_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"] +resource_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"] structure_run_tool = StructureRunTool( description="RAG Expert Agent - Structure to invoke with natural language queries about the topic of Retrieval Augmented Generation", driver=GriptapeCloudStructureRunDriver( base_url=base_url, api_key=api_key, - structure_id=structure_id, + resource_id=resource_id, ), ) diff --git a/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py b/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py index cc83797d7..67fc118f5 100644 --- a/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py +++ b/griptape/drivers/structure_run/griptape_cloud_structure_run_driver.py @@ -1,38 +1,42 @@ from __future__ import annotations -import time -from typing import Any +import os +from typing import Any, Literal, Optional from urllib.parse import urljoin from attrs import Factory, define, field from griptape.artifacts import BaseArtifact, InfoArtifact from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver +from griptape.events import BaseEvent, EventBus @define class GriptapeCloudStructureRunDriver(BaseStructureRunDriver): - base_url: str = field(default="https://cloud.griptape.ai", kw_only=True) - api_key: str = field(kw_only=True) + base_url: str = field( + default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), + ) + api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY"))) headers: dict = field( default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), kw_only=True, ) - structure_id: str = field(kw_only=True) - structure_run_wait_time_interval: int = field(default=2, kw_only=True) - structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True) + resource_type: Literal["structure", "assistant"] = field(default="structure", kw_only=True) + resource_id: str = field(kw_only=True) + run_wait_time_interval: int = field(default=2, kw_only=True) + run_max_wait_time_attempts: int = field(default=20, kw_only=True) async_run: bool = field(default=False, kw_only=True) def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact: from requests import Response, post - url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/runs") + url = urljoin(self.base_url.strip("/"), f"/api/{self.resource_type}s/{self.resource_id}/runs") env_vars = [{"name": key, "value": value, "source": "manual"} for key, value in self.env.items()] response: Response = post( url, - json={"args": [arg.value for arg in args], "env_vars": env_vars}, + json={"args": [arg.value for arg in args], "env_vars": env_vars, "stream": True}, headers=self.headers, ) response.raise_for_status() @@ -41,40 +45,30 @@ def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact: if self.async_run: return InfoArtifact("Run started successfully") else: - return self._get_structure_run_result(response_json["structure_run_id"]) - - def _get_structure_run_result(self, structure_run_id: str) -> BaseArtifact | InfoArtifact: - url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{structure_run_id}") - - result = self._get_structure_run_result_attempt(url) - status = result["status"] - - wait_attempts = 0 - while ( - status not in ("SUCCEEDED", "FAILED", "ERROR", "CANCELLED") - and wait_attempts < self.structure_run_max_wait_time_attempts - ): - # wait - time.sleep(self.structure_run_wait_time_interval) - wait_attempts += 1 - result = self._get_structure_run_result_attempt(url) - status = result["status"] - - if wait_attempts >= self.structure_run_max_wait_time_attempts: - raise Exception(f"Failed to get Run result after {self.structure_run_max_wait_time_attempts} attempts.") - - if status != "SUCCEEDED": - raise Exception(f"Run failed with status: {status}") - - if "output" in result: - return BaseArtifact.from_dict(result["output"]) - else: - return InfoArtifact("No output found in response") - - def _get_structure_run_result_attempt(self, structure_run_url: str) -> Any: + return self._get_run_result(response_json[f"{self.resource_type}_run_id"]) + + def _get_run_result(self, structure_run_id: str) -> BaseArtifact | InfoArtifact: + url = urljoin(self.base_url.strip("/"), f"/api/{self.resource_type}-runs/{structure_run_id}/events") + + result = self._get_run_result_attempt(url) + events = result["events"] + offset = result["offset"] + + output = None + while not output: + result = self._get_run_result_attempt(url, offset=offset) + events = result["events"] + offset = result["next_offset"] + for event in events: + EventBus.publish_event(BaseEvent.from_dict(event["payload"])) + if event["type"] == "FinishStructureRunEvent": + output = BaseArtifact.from_dict(event["payload"]["output_task_output"]) + return output + + def _get_run_result_attempt(self, structure_run_url: str, offset: int = 0) -> Any: from requests import Response, get - response: Response = get(structure_run_url, headers=self.headers) + response: Response = get(structure_run_url, headers=self.headers, params={"offset": offset}) response.raise_for_status() return response.json() diff --git a/griptape/utils/stream.py b/griptape/utils/stream.py index af8c65b3b..51078355f 100644 --- a/griptape/utils/stream.py +++ b/griptape/utils/stream.py @@ -5,7 +5,7 @@ from threading import Thread from typing import TYPE_CHECKING -from attrs import Attribute, Factory, define, field +from attrs import Factory, define, field from griptape.artifacts.text_artifact import TextArtifact from griptape.events import ( @@ -41,16 +41,6 @@ class Stream: structure: Structure = field() - @structure.validator # pyright: ignore[reportAttributeAccessIssue] - def validate_structure(self, _: Attribute, structure: Structure) -> None: - from griptape.tasks import PromptTask - - streaming_tasks = [ - task for task in structure.tasks if isinstance(task, PromptTask) and task.prompt_driver.stream - ] - if not streaming_tasks: - raise ValueError("Structure does not have any streaming tasks, enable with stream=True") - _event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue())) def run(self, *args) -> Iterator[TextArtifact]: diff --git a/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py index bbd6bac2b..01155d5c8 100644 --- a/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py +++ b/tests/unit/drivers/structure_run/test_griptape_cloud_structure_run_driver.py @@ -24,7 +24,7 @@ def driver(self, mocker, mock_requests_post): return GriptapeCloudStructureRunDriver( base_url="https://cloud-foo.griptape.ai", api_key="foo bar", - structure_id="1", + resource_id="1", env={"key": "value"}, structure_run_wait_time_interval=0, )