Skip to content

Commit

Permalink
GriptapeCloudFileManagerDriver (#1267)
Browse files Browse the repository at this point in the history
  • Loading branch information
cjkindel authored Oct 21, 2024
1 parent aee0dc4 commit e6ffc90
Show file tree
Hide file tree
Showing 20 changed files with 707 additions and 133 deletions.
5 changes: 5 additions & 0 deletions .github/workflows/docs-integration-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,11 @@ jobs:
ASTRA_DB_APPLICATION_TOKEN: ${{ secrets.INTEG_ASTRA_DB_APPLICATION_TOKEN }}
TAVILY_API_KEY: ${{ secrets.INTEG_TAVILY_API_KEY }}
EXA_API_KEY: ${{ secrets.INTEG_EXA_API_KEY }}
AMAZON_S3_BUCKET: ${{ secrets.INTEG_AMAZON_S3_BUCKET }}
AMAZON_S3_KEY: ${{ secrets.INTEG_AMAZON_S3_KEY }}
GT_CLOUD_BUCKET_ID: ${{ secrets.INTEG_GT_CLOUD_BUCKET_ID }}
GT_CLOUD_ASSET_NAME: ${{ secrets.INTEG_GT_CLOUD_ASSET_NAME }}

services:
postgres:
image: ankane/pgvector:v0.5.0
Expand Down
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Exponential backoff to `BaseEventListenerDriver` for retrying failed event publishing.
- `BaseTask.task_outputs` to get a dictionary of all task outputs. This has been added to `Workflow.context` and `Pipeline.context`.
- `Chat.input_fn` for customizing the input to the Chat utility.
- `GriptapeCloudFileManagerDriver` for managing files on Griptape Cloud.
- `BaseFileManagerDriver.load_artifact()` & `BaseFileManagerDriver.save_artifact()` for loading & saving artifacts as files.

### Changed

Expand Down
48 changes: 48 additions & 0 deletions docs/griptape-framework/drivers/file-manager-drivers.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
---
search:
boost: 2
---

## Overview

File Manager Drivers can be used to load and save files with local or external file systems.

You can use File Manager Drivers with Loaders:

```python
--8<-- "docs/griptape-framework/drivers/src/file_manager_driver.py"
```

Or use them independently as shown below for each driver:

## File Manager Drivers

### Griptape Cloud

!!! info
This driver requires the `drivers-file-manager-griptape-cloud` [extra](../index.md#extras).

The [GriptapeCloudFileManagerDriver](../../reference/griptape/drivers/file_manager/griptape_cloud_file_manager_driver.md) allows you to load and save files sourced from Griptape Cloud Asset and Bucket resources.

```python
--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_file_manager_driver.py"
```

### Local

The [LocalFileManagerDriver](../../reference/griptape/drivers/file_manager/local_file_manager_driver.md) allows you to load and save files sourced from a local directory.

```python
--8<-- "docs/griptape-framework/drivers/src/local_file_manager_driver.py"
```

### Amazon S3

!!! info
This driver requires the `drivers-file-manager-amazon-s3` [extra](../index.md#extras).

The [LocalFile ManagerDriver](../../reference/griptape/drivers/file_manager/amazon_s3_file_manager_driver.md) allows you to load and save files sourced from an Amazon S3 bucket.

```python
--8<-- "docs/griptape-framework/drivers/src/amazon_s3_file_manager_driver.py"
```
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import os

import boto3

from griptape.drivers import AmazonS3FileManagerDriver

amazon_s3_file_manager_driver = AmazonS3FileManagerDriver(
bucket=os.environ["AMAZON_S3_BUCKET"],
session=boto3.Session(
region_name=os.environ["AWS_DEFAULT_REGION"],
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"],
aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"],
),
)

# Download File
file_contents = amazon_s3_file_manager_driver.load_file(os.environ["AMAZON_S3_KEY"])

print(file_contents)

# Upload File
response = amazon_s3_file_manager_driver.save_file(os.environ["AMAZON_S3_KEY"], file_contents.value)

print(response)
9 changes: 9 additions & 0 deletions docs/griptape-framework/drivers/src/file_manager_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from griptape.drivers import LocalFileManagerDriver
from griptape.loaders import TextLoader

local_file_manager_driver = LocalFileManagerDriver()

loader = TextLoader(file_manager_driver=local_file_manager_driver)
text_artifact = loader.load("tests/resources/test.txt")

print(text_artifact.value)
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
import os

from griptape.drivers import GriptapeCloudFileManagerDriver

gtc_file_manager_driver = GriptapeCloudFileManagerDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
bucket_id=os.environ["GT_CLOUD_BUCKET_ID"],
)

# Download File
file_contents = gtc_file_manager_driver.load_file(os.environ["GT_CLOUD_ASSET_NAME"])

print(file_contents)

# Upload File
response = gtc_file_manager_driver.save_file(os.environ["GT_CLOUD_ASSET_NAME"], file_contents.value)

print(response)
13 changes: 13 additions & 0 deletions docs/griptape-framework/drivers/src/local_file_manager_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
from griptape.drivers import LocalFileManagerDriver

local_file_manager_driver = LocalFileManagerDriver()

# Download File
file_contents = local_file_manager_driver.load_file("tests/resources/test.txt")

print(file_contents)

# Upload File
response = local_file_manager_driver.save_file("tests/resources/test.txt", file_contents.value)

print(response)
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,7 @@
from .file_manager.base_file_manager_driver import BaseFileManagerDriver
from .file_manager.local_file_manager_driver import LocalFileManagerDriver
from .file_manager.amazon_s3_file_manager_driver import AmazonS3FileManagerDriver
from .file_manager.griptape_cloud_file_manager_driver import GriptapeCloudFileManagerDriver

from .rerank.base_rerank_driver import BaseRerankDriver
from .rerank.cohere_rerank_driver import CohereRerankDriver
Expand Down Expand Up @@ -230,6 +231,7 @@
"BaseFileManagerDriver",
"LocalFileManagerDriver",
"AmazonS3FileManagerDriver",
"GriptapeCloudFileManagerDriver",
"BaseRerankDriver",
"CohereRerankDriver",
"BaseRulesetDriver",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,12 @@ def try_load_file(self, path: str) -> bytes:
raise FileNotFoundError from e
raise e

def try_save_file(self, path: str, value: bytes) -> None:
def try_save_file(self, path: str, value: bytes) -> str:
full_key = self._to_full_key(path)
if self._is_a_directory(full_key):
raise IsADirectoryError
self.client.put_object(Bucket=self.bucket, Key=full_key, Body=value)
return f"s3://{self.bucket}/{full_key}"

def _to_full_key(self, path: str) -> str:
path = path.lstrip("/")
Expand Down
22 changes: 18 additions & 4 deletions griptape/drivers/file_manager/base_file_manager_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from attrs import define, field

from griptape.artifacts import BlobArtifact, InfoArtifact, TextArtifact
from griptape.artifacts import BaseArtifact, BlobArtifact, InfoArtifact, TextArtifact


@define
Expand Down Expand Up @@ -42,9 +42,23 @@ def save_file(self, path: str, value: bytes | str) -> InfoArtifact:
elif isinstance(value, (bytearray, memoryview)):
raise ValueError(f"Unsupported type: {type(value)}")

self.try_save_file(path, value)
location = self.try_save_file(path, value)

return InfoArtifact("Successfully saved file")
return InfoArtifact(f"Successfully saved file at: {location}")

@abstractmethod
def try_save_file(self, path: str, value: bytes) -> None: ...
def try_save_file(self, path: str, value: bytes) -> str: ...

def load_artifact(self, path: str) -> BaseArtifact:
response = self.try_load_file(path)
return BaseArtifact.from_json(
response.decode() if self.encoding is None else response.decode(encoding=self.encoding)
)

def save_artifact(self, path: str, artifact: BaseArtifact) -> InfoArtifact:
artifact_json = artifact.to_json()
value = artifact_json.encode() if self.encoding is None else artifact_json.encode(encoding=self.encoding)

location = self.try_save_file(path, value)

return InfoArtifact(f"Successfully saved artifact at: {location}")
153 changes: 153 additions & 0 deletions griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING, Optional
from urllib.parse import urljoin

import requests
from attrs import Attribute, Factory, define, field

from griptape.drivers import BaseFileManagerDriver
from griptape.utils import import_optional_dependency

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from azure.storage.blob import BlobClient


@define
class GriptapeCloudFileManagerDriver(BaseFileManagerDriver):
"""GriptapeCloudFileManagerDriver can be used to list, load, and save files as Assets in Griptape Cloud Buckets.
Attributes:
bucket_id: The ID of the Bucket to list, load, and save Assets in. If not provided, the driver will attempt to
retrieve the ID from the environment variable `GT_CLOUD_BUCKET_ID`.
workdir: The working directory. List, load, and save operations will be performed relative to this directory.
base_url: The base URL of the Griptape Cloud API. Defaults to the value of the environment variable
`GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`.
api_key: The API key to use for authenticating with the Griptape Cloud API. If not provided, the driver will
attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`.
Raises:
ValueError: If `api_key` is not provided, if `workdir` does not start with "/"", or invalid `bucket_id` and/or `bucket_name` value(s) are provided.
"""

bucket_id: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_BUCKET_ID")), kw_only=True)
workdir: str = field(default="/", kw_only=True)
base_url: str = field(
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
)
api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY")))
headers: dict = field(
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True),
init=False,
)

@workdir.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_workdir(self, _: Attribute, workdir: str) -> None:
if not workdir.startswith("/"):
raise ValueError(f"{self.__class__.__name__} requires 'workdir' to be an absolute path, starting with `/`")

@api_key.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str:
if value is None:
raise ValueError(f"{self.__class__.__name__} requires an API key")
return value

@bucket_id.validator # pyright: ignore[reportAttributeAccessIssue]
def validate_bucket_id(self, _: Attribute, value: Optional[str]) -> str:
if value is None:
raise ValueError(f"{self.__class__.__name__} requires an Bucket ID")
return value

def __attrs_post_init__(self) -> None:
try:
self._call_api(method="get", path=f"/buckets/{self.bucket_id}").json()
except requests.exceptions.HTTPError as e:
if e.response.status_code == 404:
raise ValueError(f"No Bucket found with ID: {self.bucket_id}") from e
raise ValueError(f"Unexpected error when retrieving Bucket with ID: {self.bucket_id}") from e

def try_list_files(self, path: str, postfix: str = "") -> list[str]:
full_key = self._to_full_key(path)

if not self._is_a_directory(full_key):
raise NotADirectoryError

data = {"prefix": full_key}
if postfix:
data["postfix"] = postfix
# TODO: GTC SDK: Pagination
list_assets_response = self._call_api(
method="list", path=f"/buckets/{self.bucket_id}/assets", json=data, raise_for_status=False
).json()

return [asset["name"] for asset in list_assets_response.get("assets", [])]

def try_load_file(self, path: str) -> bytes:
full_key = self._to_full_key(path)

if self._is_a_directory(full_key):
raise IsADirectoryError

try:
blob_client = self._get_blob_client(full_key=full_key)
except requests.exceptions.HTTPError as e:
if e.response.status_code == 404:
raise FileNotFoundError from e
raise e

try:
return blob_client.download_blob().readall()
except import_optional_dependency("azure.core.exceptions").ResourceNotFoundError as e:
raise FileNotFoundError from e

def try_save_file(self, path: str, value: bytes) -> str:
full_key = self._to_full_key(path)

if self._is_a_directory(full_key):
raise IsADirectoryError

try:
self._call_api(method="get", path=f"/buckets/{self.bucket_id}/assets/{full_key}", raise_for_status=True)
except requests.exceptions.HTTPError as e:
if e.response.status_code == 404:
logger.info("Asset '%s' not found, attempting to create", full_key)
data = {"name": full_key}
self._call_api(method="put", path=f"/buckets/{self.bucket_id}/assets", json=data, raise_for_status=True)
else:
raise e

blob_client = self._get_blob_client(full_key=full_key)

blob_client.upload_blob(data=value, overwrite=True)
return f"buckets/{self.bucket_id}/assets/{full_key}"

def _get_blob_client(self, full_key: str) -> BlobClient:
url_response = self._call_api(
method="post", path=f"/buckets/{self.bucket_id}/asset-urls/{full_key}", raise_for_status=True
).json()
sas_url = url_response["url"]
return import_optional_dependency("azure.storage.blob").BlobClient.from_blob_url(blob_url=sas_url)

def _get_url(self, path: str) -> str:
path = path.lstrip("/")
return urljoin(self.base_url, f"/api/{path}")

def _call_api(
self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True
) -> requests.Response:
res = requests.request(method, self._get_url(path), json=json, headers=self.headers)
if raise_for_status:
res.raise_for_status()
return res

def _is_a_directory(self, path: str) -> bool:
return path == "" or path.endswith("/")

def _to_full_key(self, path: str) -> str:
path = path.lstrip("/")
full_key = f"{self.workdir}/{path}"
return full_key.lstrip("/")
3 changes: 2 additions & 1 deletion griptape/drivers/file_manager/local_file_manager_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,13 @@ def try_load_file(self, path: str) -> bytes:
raise IsADirectoryError
return Path(full_path).read_bytes()

def try_save_file(self, path: str, value: bytes) -> None:
def try_save_file(self, path: str, value: bytes) -> str:
full_path = self._full_path(path)
if self._is_dir(full_path):
raise IsADirectoryError
os.makedirs(os.path.dirname(full_path), exist_ok=True)
Path(full_path).write_bytes(value)
return full_path

def _full_path(self, path: str) -> str:
full_path = path if self.workdir is None else os.path.join(self.workdir, path.lstrip("/"))
Expand Down
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ nav:
- Web Search Drivers: "griptape-framework/drivers/web-search-drivers.md"
- Observability Drivers: "griptape-framework/drivers/observability-drivers.md"
- Ruleset Drivers: "griptape-framework/drivers/ruleset-drivers.md"
- File Manager Drivers: "griptape-framework/drivers/file-manager-drivers.md"
- Data:
- Overview: "griptape-framework/data/index.md"
- Artifacts: "griptape-framework/data/artifacts.md"
Expand Down
Loading

0 comments on commit e6ffc90

Please sign in to comment.