Skip to content

Commit

Permalink
Add support for assistants to structure run driver
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 30, 2024
1 parent a78fc60 commit cb75a6e
Show file tree
Hide file tree
Showing 6 changed files with 41 additions and 61 deletions.
4 changes: 0 additions & 4 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ Defaults.drivers_config = AnthropicDriversConfig(

Many callables have been renamed for consistency. Update your code to use the new names using the [CHANGELOG.md](https://github.com/griptape-ai/griptape/pull/1275/files#diff-06572a96a58dc510037d5efa622f9bec8519bc1beab13c9f251e97e657a9d4ed) as the source of truth.


### Removed `CompletionChunkEvent`

`CompletionChunkEvent` has been removed. There is now `BaseChunkEvent` with children `TextChunkEvent` and `ActionChunkEvent`. `BaseChunkEvent` can replace `completion_chunk_event.token` by doing `str(base_chunk_event)`.
Expand Down Expand Up @@ -146,7 +145,6 @@ event_listener_driver.flush_events()

The `observable` decorator has been moved to `griptape.common.decorators`. Update your imports accordingly.


#### Before

```python
Expand Down Expand Up @@ -183,7 +181,6 @@ driver = HuggingFacePipelinePromptDriver(

`execute` has been renamed to `run` in several places. Update your code accordingly.


#### Before

```python
Expand Down Expand Up @@ -298,7 +295,6 @@ pip install griptape[drivers-prompt-huggingface-hub]
pip install torch
```


### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types

`CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a `list[TextArtifact]` instead of `list[CsvRowArtifact]`.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

base_url = os.environ["GRIPTAPE_CLOUD_BASE_URL"]
api_key = os.environ["GRIPTAPE_CLOUD_API_KEY"]
structure_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"]
resource_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"]


pipeline = Pipeline(
Expand All @@ -32,7 +32,7 @@
driver=GriptapeCloudStructureRunDriver(
base_url=base_url,
api_key=api_key,
structure_id=structure_id,
resource_id=resource_id,
),
),
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

base_url = os.environ["GRIPTAPE_CLOUD_BASE_URL"]
api_key = os.environ["GRIPTAPE_CLOUD_API_KEY"]
structure_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"]
resource_id = os.environ["GRIPTAPE_CLOUD_STRUCTURE_ID"]

structure_run_tool = StructureRunTool(
description="RAG Expert Agent - Structure to invoke with natural language queries about the topic of Retrieval Augmented Generation",
driver=GriptapeCloudStructureRunDriver(
base_url=base_url,
api_key=api_key,
structure_id=structure_id,
resource_id=resource_id,
),
)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,38 +1,42 @@
from __future__ import annotations

import time
from typing import Any
import os
from typing import Any, Literal, Optional
from urllib.parse import urljoin

from attrs import Factory, define, field

from griptape.artifacts import BaseArtifact, InfoArtifact
from griptape.drivers.structure_run.base_structure_run_driver import BaseStructureRunDriver
from griptape.events import BaseEvent, EventBus


@define
class GriptapeCloudStructureRunDriver(BaseStructureRunDriver):
base_url: str = field(default="https://cloud.griptape.ai", kw_only=True)
api_key: str = field(kw_only=True)
base_url: str = field(
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
)
api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")))
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
kw_only=True,
)
structure_id: str = field(kw_only=True)
structure_run_wait_time_interval: int = field(default=2, kw_only=True)
structure_run_max_wait_time_attempts: int = field(default=20, kw_only=True)
resource_type: Literal["structure", "assistant"] = field(default="structure", kw_only=True)
resource_id: str = field(kw_only=True)
run_wait_time_interval: int = field(default=2, kw_only=True)
run_max_wait_time_attempts: int = field(default=20, kw_only=True)
async_run: bool = field(default=False, kw_only=True)

def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
from requests import Response, post

url = urljoin(self.base_url.strip("/"), f"/api/structures/{self.structure_id}/runs")
url = urljoin(self.base_url.strip("/"), f"/api/{self.resource_type}s/{self.resource_id}/runs")

env_vars = [{"name": key, "value": value, "source": "manual"} for key, value in self.env.items()]

response: Response = post(
url,
json={"args": [arg.value for arg in args], "env_vars": env_vars},
json={"args": [arg.value for arg in args], "env_vars": env_vars, "stream": True},
headers=self.headers,
)
response.raise_for_status()
Expand All @@ -41,40 +45,30 @@ def try_run(self, *args: BaseArtifact) -> BaseArtifact | InfoArtifact:
if self.async_run:
return InfoArtifact("Run started successfully")
else:
return self._get_structure_run_result(response_json["structure_run_id"])

def _get_structure_run_result(self, structure_run_id: str) -> BaseArtifact | InfoArtifact:
url = urljoin(self.base_url.strip("/"), f"/api/structure-runs/{structure_run_id}")

result = self._get_structure_run_result_attempt(url)
status = result["status"]

wait_attempts = 0
while (
status not in ("SUCCEEDED", "FAILED", "ERROR", "CANCELLED")
and wait_attempts < self.structure_run_max_wait_time_attempts
):
# wait
time.sleep(self.structure_run_wait_time_interval)
wait_attempts += 1
result = self._get_structure_run_result_attempt(url)
status = result["status"]

if wait_attempts >= self.structure_run_max_wait_time_attempts:
raise Exception(f"Failed to get Run result after {self.structure_run_max_wait_time_attempts} attempts.")

if status != "SUCCEEDED":
raise Exception(f"Run failed with status: {status}")

if "output" in result:
return BaseArtifact.from_dict(result["output"])
else:
return InfoArtifact("No output found in response")

def _get_structure_run_result_attempt(self, structure_run_url: str) -> Any:
return self._get_run_result(response_json[f"{self.resource_type}_run_id"])

def _get_run_result(self, structure_run_id: str) -> BaseArtifact | InfoArtifact:
url = urljoin(self.base_url.strip("/"), f"/api/{self.resource_type}-runs/{structure_run_id}/events")

result = self._get_run_result_attempt(url)
events = result["events"]
offset = result["offset"]

output = None
while not output:
result = self._get_run_result_attempt(url, offset=offset)
events = result["events"]
offset = result["next_offset"]
for event in events:
EventBus.publish_event(BaseEvent.from_dict(event["payload"]))
if event["type"] == "FinishStructureRunEvent":
output = BaseArtifact.from_dict(event["payload"]["output_task_output"])
return output

def _get_run_result_attempt(self, structure_run_url: str, offset: int = 0) -> Any:
from requests import Response, get

response: Response = get(structure_run_url, headers=self.headers)
response: Response = get(structure_run_url, headers=self.headers, params={"offset": offset})
response.raise_for_status()

return response.json()
12 changes: 1 addition & 11 deletions griptape/utils/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from threading import Thread
from typing import TYPE_CHECKING

from attrs import Attribute, Factory, define, field
from attrs import Factory, define, field

from griptape.artifacts.text_artifact import TextArtifact
from griptape.events import (
Expand Down Expand Up @@ -41,16 +41,6 @@ class Stream:

structure: Structure = field()

@structure.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_structure(self, _: Attribute, structure: Structure) -> None:
from griptape.tasks import PromptTask

streaming_tasks = [
task for task in structure.tasks if isinstance(task, PromptTask) and task.prompt_driver.stream
]
if not streaming_tasks:
raise ValueError("Structure does not have any streaming tasks, enable with stream=True")

_event_queue: Queue[BaseEvent] = field(default=Factory(lambda: Queue()))

def run(self, *args) -> Iterator[TextArtifact]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def driver(self, mocker, mock_requests_post):
return GriptapeCloudStructureRunDriver(
base_url="https://cloud-foo.griptape.ai",
api_key="foo bar",
structure_id="1",
resource_id="1",
env={"key": "value"},
structure_run_wait_time_interval=0,
)
Expand Down

0 comments on commit cb75a6e

Please sign in to comment.