-
Notifications
You must be signed in to change notification settings - Fork 183
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
GriptapeCloudFileManagerDriver (#1267)
- Loading branch information
Showing
20 changed files
with
707 additions
and
133 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
--- | ||
search: | ||
boost: 2 | ||
--- | ||
|
||
## Overview | ||
|
||
File Manager Drivers can be used to load and save files with local or external file systems. | ||
|
||
You can use File Manager Drivers with Loaders: | ||
|
||
```python | ||
--8<-- "docs/griptape-framework/drivers/src/file_manager_driver.py" | ||
``` | ||
|
||
Or use them independently as shown below for each driver: | ||
|
||
## File Manager Drivers | ||
|
||
### Griptape Cloud | ||
|
||
!!! info | ||
This driver requires the `drivers-file-manager-griptape-cloud` [extra](../index.md#extras). | ||
|
||
The [GriptapeCloudFileManagerDriver](../../reference/griptape/drivers/file_manager/griptape_cloud_file_manager_driver.md) allows you to load and save files sourced from Griptape Cloud Asset and Bucket resources. | ||
|
||
```python | ||
--8<-- "docs/griptape-framework/drivers/src/griptape_cloud_file_manager_driver.py" | ||
``` | ||
|
||
### Local | ||
|
||
The [LocalFileManagerDriver](../../reference/griptape/drivers/file_manager/local_file_manager_driver.md) allows you to load and save files sourced from a local directory. | ||
|
||
```python | ||
--8<-- "docs/griptape-framework/drivers/src/local_file_manager_driver.py" | ||
``` | ||
|
||
### Amazon S3 | ||
|
||
!!! info | ||
This driver requires the `drivers-file-manager-amazon-s3` [extra](../index.md#extras). | ||
|
||
The [LocalFile ManagerDriver](../../reference/griptape/drivers/file_manager/amazon_s3_file_manager_driver.md) allows you to load and save files sourced from an Amazon S3 bucket. | ||
|
||
```python | ||
--8<-- "docs/griptape-framework/drivers/src/amazon_s3_file_manager_driver.py" | ||
``` |
24 changes: 24 additions & 0 deletions
24
docs/griptape-framework/drivers/src/amazon_s3_file_manager_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
import os | ||
|
||
import boto3 | ||
|
||
from griptape.drivers import AmazonS3FileManagerDriver | ||
|
||
amazon_s3_file_manager_driver = AmazonS3FileManagerDriver( | ||
bucket=os.environ["AMAZON_S3_BUCKET"], | ||
session=boto3.Session( | ||
region_name=os.environ["AWS_DEFAULT_REGION"], | ||
aws_access_key_id=os.environ["AWS_ACCESS_KEY_ID"], | ||
aws_secret_access_key=os.environ["AWS_SECRET_ACCESS_KEY"], | ||
), | ||
) | ||
|
||
# Download File | ||
file_contents = amazon_s3_file_manager_driver.load_file(os.environ["AMAZON_S3_KEY"]) | ||
|
||
print(file_contents) | ||
|
||
# Upload File | ||
response = amazon_s3_file_manager_driver.save_file(os.environ["AMAZON_S3_KEY"], file_contents.value) | ||
|
||
print(response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
from griptape.drivers import LocalFileManagerDriver | ||
from griptape.loaders import TextLoader | ||
|
||
local_file_manager_driver = LocalFileManagerDriver() | ||
|
||
loader = TextLoader(file_manager_driver=local_file_manager_driver) | ||
text_artifact = loader.load("tests/resources/test.txt") | ||
|
||
print(text_artifact.value) |
18 changes: 18 additions & 0 deletions
18
docs/griptape-framework/drivers/src/griptape_cloud_file_manager_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
import os | ||
|
||
from griptape.drivers import GriptapeCloudFileManagerDriver | ||
|
||
gtc_file_manager_driver = GriptapeCloudFileManagerDriver( | ||
api_key=os.environ["GT_CLOUD_API_KEY"], | ||
bucket_id=os.environ["GT_CLOUD_BUCKET_ID"], | ||
) | ||
|
||
# Download File | ||
file_contents = gtc_file_manager_driver.load_file(os.environ["GT_CLOUD_ASSET_NAME"]) | ||
|
||
print(file_contents) | ||
|
||
# Upload File | ||
response = gtc_file_manager_driver.save_file(os.environ["GT_CLOUD_ASSET_NAME"], file_contents.value) | ||
|
||
print(response) |
13 changes: 13 additions & 0 deletions
13
docs/griptape-framework/drivers/src/local_file_manager_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
from griptape.drivers import LocalFileManagerDriver | ||
|
||
local_file_manager_driver = LocalFileManagerDriver() | ||
|
||
# Download File | ||
file_contents = local_file_manager_driver.load_file("tests/resources/test.txt") | ||
|
||
print(file_contents) | ||
|
||
# Upload File | ||
response = local_file_manager_driver.save_file("tests/resources/test.txt", file_contents.value) | ||
|
||
print(response) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
153 changes: 153 additions & 0 deletions
153
griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,153 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import os | ||
from typing import TYPE_CHECKING, Optional | ||
from urllib.parse import urljoin | ||
|
||
import requests | ||
from attrs import Attribute, Factory, define, field | ||
|
||
from griptape.drivers import BaseFileManagerDriver | ||
from griptape.utils import import_optional_dependency | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
if TYPE_CHECKING: | ||
from azure.storage.blob import BlobClient | ||
|
||
|
||
@define | ||
class GriptapeCloudFileManagerDriver(BaseFileManagerDriver): | ||
"""GriptapeCloudFileManagerDriver can be used to list, load, and save files as Assets in Griptape Cloud Buckets. | ||
Attributes: | ||
bucket_id: The ID of the Bucket to list, load, and save Assets in. If not provided, the driver will attempt to | ||
retrieve the ID from the environment variable `GT_CLOUD_BUCKET_ID`. | ||
workdir: The working directory. List, load, and save operations will be performed relative to this directory. | ||
base_url: The base URL of the Griptape Cloud API. Defaults to the value of the environment variable | ||
`GT_CLOUD_BASE_URL` or `https://cloud.griptape.ai`. | ||
api_key: The API key to use for authenticating with the Griptape Cloud API. If not provided, the driver will | ||
attempt to retrieve the API key from the environment variable `GT_CLOUD_API_KEY`. | ||
Raises: | ||
ValueError: If `api_key` is not provided, if `workdir` does not start with "/"", or invalid `bucket_id` and/or `bucket_name` value(s) are provided. | ||
""" | ||
|
||
bucket_id: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_BUCKET_ID")), kw_only=True) | ||
workdir: str = field(default="/", kw_only=True) | ||
base_url: str = field( | ||
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), | ||
) | ||
api_key: Optional[str] = field(default=Factory(lambda: os.getenv("GT_CLOUD_API_KEY"))) | ||
headers: dict = field( | ||
default=Factory(lambda self: {"Authorization": f"Bearer {self.api_key}"}, takes_self=True), | ||
init=False, | ||
) | ||
|
||
@workdir.validator # pyright: ignore[reportAttributeAccessIssue] | ||
def validate_workdir(self, _: Attribute, workdir: str) -> None: | ||
if not workdir.startswith("/"): | ||
raise ValueError(f"{self.__class__.__name__} requires 'workdir' to be an absolute path, starting with `/`") | ||
|
||
@api_key.validator # pyright: ignore[reportAttributeAccessIssue] | ||
def validate_api_key(self, _: Attribute, value: Optional[str]) -> str: | ||
if value is None: | ||
raise ValueError(f"{self.__class__.__name__} requires an API key") | ||
return value | ||
|
||
@bucket_id.validator # pyright: ignore[reportAttributeAccessIssue] | ||
def validate_bucket_id(self, _: Attribute, value: Optional[str]) -> str: | ||
if value is None: | ||
raise ValueError(f"{self.__class__.__name__} requires an Bucket ID") | ||
return value | ||
|
||
def __attrs_post_init__(self) -> None: | ||
try: | ||
self._call_api(method="get", path=f"/buckets/{self.bucket_id}").json() | ||
except requests.exceptions.HTTPError as e: | ||
if e.response.status_code == 404: | ||
raise ValueError(f"No Bucket found with ID: {self.bucket_id}") from e | ||
raise ValueError(f"Unexpected error when retrieving Bucket with ID: {self.bucket_id}") from e | ||
|
||
def try_list_files(self, path: str, postfix: str = "") -> list[str]: | ||
full_key = self._to_full_key(path) | ||
|
||
if not self._is_a_directory(full_key): | ||
raise NotADirectoryError | ||
|
||
data = {"prefix": full_key} | ||
if postfix: | ||
data["postfix"] = postfix | ||
# TODO: GTC SDK: Pagination | ||
list_assets_response = self._call_api( | ||
method="list", path=f"/buckets/{self.bucket_id}/assets", json=data, raise_for_status=False | ||
).json() | ||
|
||
return [asset["name"] for asset in list_assets_response.get("assets", [])] | ||
|
||
def try_load_file(self, path: str) -> bytes: | ||
full_key = self._to_full_key(path) | ||
|
||
if self._is_a_directory(full_key): | ||
raise IsADirectoryError | ||
|
||
try: | ||
blob_client = self._get_blob_client(full_key=full_key) | ||
except requests.exceptions.HTTPError as e: | ||
if e.response.status_code == 404: | ||
raise FileNotFoundError from e | ||
raise e | ||
|
||
try: | ||
return blob_client.download_blob().readall() | ||
except import_optional_dependency("azure.core.exceptions").ResourceNotFoundError as e: | ||
raise FileNotFoundError from e | ||
|
||
def try_save_file(self, path: str, value: bytes) -> str: | ||
full_key = self._to_full_key(path) | ||
|
||
if self._is_a_directory(full_key): | ||
raise IsADirectoryError | ||
|
||
try: | ||
self._call_api(method="get", path=f"/buckets/{self.bucket_id}/assets/{full_key}", raise_for_status=True) | ||
except requests.exceptions.HTTPError as e: | ||
if e.response.status_code == 404: | ||
logger.info("Asset '%s' not found, attempting to create", full_key) | ||
data = {"name": full_key} | ||
self._call_api(method="put", path=f"/buckets/{self.bucket_id}/assets", json=data, raise_for_status=True) | ||
else: | ||
raise e | ||
|
||
blob_client = self._get_blob_client(full_key=full_key) | ||
|
||
blob_client.upload_blob(data=value, overwrite=True) | ||
return f"buckets/{self.bucket_id}/assets/{full_key}" | ||
|
||
def _get_blob_client(self, full_key: str) -> BlobClient: | ||
url_response = self._call_api( | ||
method="post", path=f"/buckets/{self.bucket_id}/asset-urls/{full_key}", raise_for_status=True | ||
).json() | ||
sas_url = url_response["url"] | ||
return import_optional_dependency("azure.storage.blob").BlobClient.from_blob_url(blob_url=sas_url) | ||
|
||
def _get_url(self, path: str) -> str: | ||
path = path.lstrip("/") | ||
return urljoin(self.base_url, f"/api/{path}") | ||
|
||
def _call_api( | ||
self, method: str, path: str, json: Optional[dict] = None, *, raise_for_status: bool = True | ||
) -> requests.Response: | ||
res = requests.request(method, self._get_url(path), json=json, headers=self.headers) | ||
if raise_for_status: | ||
res.raise_for_status() | ||
return res | ||
|
||
def _is_a_directory(self, path: str) -> bool: | ||
return path == "" or path.endswith("/") | ||
|
||
def _to_full_key(self, path: str) -> str: | ||
path = path.lstrip("/") | ||
full_key = f"{self.workdir}/{path}" | ||
return full_key.lstrip("/") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.