Skip to content

Commit

Permalink
Remove azure deps from Griptape Cloud File Manager Driver
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Oct 30, 2024
1 parent a78fc60 commit af5ca11
Show file tree
Hide file tree
Showing 7 changed files with 51 additions and 134 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## Unreleased

### Changed

- Removed `azure-core` and `azure-storage-blob` dependencies.
- `GriptapeCloudFileManagerDriver` no longer requires `drivers-file-manager-griptape-cloud` extra.

## \[0.34.0\] - 2024-10-29

### Added
Expand Down
4 changes: 0 additions & 4 deletions MIGRATION.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ Defaults.drivers_config = AnthropicDriversConfig(

Many callables have been renamed for consistency. Update your code to use the new names using the [CHANGELOG.md](https://github.com/griptape-ai/griptape/pull/1275/files#diff-06572a96a58dc510037d5efa622f9bec8519bc1beab13c9f251e97e657a9d4ed) as the source of truth.


### Removed `CompletionChunkEvent`

`CompletionChunkEvent` has been removed. There is now `BaseChunkEvent` with children `TextChunkEvent` and `ActionChunkEvent`. `BaseChunkEvent` can replace `completion_chunk_event.token` by doing `str(base_chunk_event)`.
Expand Down Expand Up @@ -146,7 +145,6 @@ event_listener_driver.flush_events()

The `observable` decorator has been moved to `griptape.common.decorators`. Update your imports accordingly.


#### Before

```python
Expand Down Expand Up @@ -183,7 +181,6 @@ driver = HuggingFacePipelinePromptDriver(

`execute` has been renamed to `run` in several places. Update your code accordingly.


#### Before

```python
Expand Down Expand Up @@ -298,7 +295,6 @@ pip install griptape[drivers-prompt-huggingface-hub]
pip install torch
```


### `CsvLoader`, `DataframeLoader`, and `SqlLoader` return types

`CsvLoader`, `DataframeLoader`, and `SqlLoader` now return a `list[TextArtifact]` instead of `list[CsvRowArtifact]`.
Expand Down
3 changes: 0 additions & 3 deletions docs/griptape-framework/drivers/file-manager-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ Or use them independently as shown below for each driver:

### Griptape Cloud

!!! info
This driver requires the `drivers-file-manager-griptape-cloud` [extra](../index.md#extras).

The [GriptapeCloudFileManagerDriver](../../reference/griptape/drivers/file_manager/griptape_cloud_file_manager_driver.md) allows you to load and save files sourced from Griptape Cloud Asset and Bucket resources.

```python
Expand Down
37 changes: 16 additions & 21 deletions griptape/drivers/file_manager/griptape_cloud_file_manager_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,16 @@

import logging
import os
from typing import TYPE_CHECKING, Optional
from typing import Optional
from urllib.parse import urljoin

import requests
from attrs import Attribute, Factory, define, field

from griptape.drivers import BaseFileManagerDriver
from griptape.utils import import_optional_dependency

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
from azure.storage.blob import BlobClient


@define
class GriptapeCloudFileManagerDriver(BaseFileManagerDriver):
Expand Down Expand Up @@ -79,7 +75,6 @@ def try_list_files(self, path: str, postfix: str = "") -> list[str]:
data = {"prefix": full_key}
if postfix:
data["postfix"] = postfix
# TODO: GTC SDK: Pagination
list_assets_response = self._call_api(
method="list", path=f"/buckets/{self.bucket_id}/assets", json=data, raise_for_status=False
).json()
Expand All @@ -93,17 +88,15 @@ def try_load_file(self, path: str) -> bytes:
raise IsADirectoryError

try:
blob_client = self._get_blob_client(full_key=full_key)
sas_url, headers = self._get_asset_url(full_key)
response = requests.get(sas_url, headers=headers)
response.raise_for_status()
return response.content
except requests.exceptions.HTTPError as e:
if e.response.status_code == 404:
raise FileNotFoundError from e
raise e

try:
return blob_client.download_blob().readall()
except import_optional_dependency("azure.core.exceptions").ResourceNotFoundError as e:
raise FileNotFoundError from e

def try_save_file(self, path: str, value: bytes) -> str:
full_key = self._to_full_key(path)

Expand All @@ -114,23 +107,25 @@ def try_save_file(self, path: str, value: bytes) -> str:
self._call_api(method="get", path=f"/buckets/{self.bucket_id}/assets/{full_key}", raise_for_status=True)
except requests.exceptions.HTTPError as e:
if e.response.status_code == 404:
logger.info("Asset '%s' not found, attempting to create", full_key)
data = {"name": full_key}
self._call_api(method="put", path=f"/buckets/{self.bucket_id}/assets", json=data, raise_for_status=True)
self._call_api(
method="put",
path=f"/buckets/{self.bucket_id}/assets",
json={"name": full_key},
raise_for_status=True,
)
else:
raise e
sas_url, headers = self._get_asset_url(full_key)
response = requests.put(sas_url, data=value, headers=headers)
response.raise_for_status()

blob_client = self._get_blob_client(full_key=full_key)

blob_client.upload_blob(data=value, overwrite=True)
return f"buckets/{self.bucket_id}/assets/{full_key}"

def _get_blob_client(self, full_key: str) -> BlobClient:
def _get_asset_url(self, full_key: str) -> tuple[str, dict]:
url_response = self._call_api(
method="post", path=f"/buckets/{self.bucket_id}/asset-urls/{full_key}", raise_for_status=True
).json()
sas_url = url_response["url"]
return import_optional_dependency("azure.storage.blob").BlobClient.from_blob_url(blob_url=sas_url)
return url_response["url"], url_response.get("headers", {})

def _get_url(self, path: str) -> str:
path = path.lstrip("/")
Expand Down
55 changes: 2 additions & 53 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 0 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,6 @@ opentelemetry-exporter-otlp-proto-http = {version = "^1.25.0", optional = true}
diffusers = {version = "^0.31.0", optional = true}
tavily-python = {version = "^0.5.0", optional = true}
exa-py = {version = "^1.1.4", optional = true}
azure-core = "^1.31.0"
azure-storage-blob = "^12.23.1"

# loaders
pandas = {version = "^1.3", optional = true}
Expand Down Expand Up @@ -149,7 +147,6 @@ drivers-observability-datadog = [
drivers-image-generation-huggingface = ["diffusers", "pillow"]

drivers-file-manager-amazon-s3 = ["boto3"]
drivers-file-manager-griptape-cloud = ["azure-core", "azure-storage-blob"]

loaders-pdf = ["pypdf"]
loaders-image = ["pillow"]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@

import pytest
import requests
from azure.core.exceptions import ResourceNotFoundError


class TestGriptapeCloudFileManagerDriver:
Expand Down Expand Up @@ -98,19 +97,18 @@ def test_try_list_files_not_directory(self, mocker, driver):
driver.try_list_files("foo")

def test_try_load_file(self, mocker, driver):
mock_response = mocker.Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"url": "https://foo.bar"}
mocker.patch("requests.request", return_value=mock_response)
mock_url_response = mocker.Mock()
mock_url_response.status_code = 200
mock_url_response.json.return_value = {"url": "https://foo.bar"}
mocker.patch("requests.request", return_value=mock_url_response)

mock_bytes = b"bytes"
mock_blob_client = mocker.Mock()
mock_blob_client.download_blob.return_value.readall.return_value = mock_bytes
mocker.patch("azure.storage.blob.BlobClient.from_blob_url", return_value=mock_blob_client)
mock_file_response = mocker.Mock()
mock_file_response.status_code = 200
mock_file_response.content = b"bytes"
mocker.patch("requests.get", return_value=mock_file_response)

response = driver.try_load_file("foo")

assert response == mock_bytes
assert response == b"bytes"

def test_try_load_file_directory(self, mocker, driver):
mock_response = mocker.Mock()
Expand All @@ -121,42 +119,29 @@ def test_try_load_file_directory(self, mocker, driver):
with pytest.raises(IsADirectoryError):
driver.try_load_file("foo/")

def test_try_load_file_sas_404(self, mocker, driver):
def test_try_load_file_asset_url_404(self, mocker, driver):
mocker.patch("requests.request", side_effect=requests.exceptions.HTTPError(response=mock.Mock(status_code=404)))

with pytest.raises(FileNotFoundError):
driver.try_load_file("foo")

def test_try_load_file_sas_500(self, mocker, driver):
def test_try_load_file_asset_url_500(self, mocker, driver):
mocker.patch("requests.request", side_effect=requests.exceptions.HTTPError(response=mock.Mock(status_code=500)))

with pytest.raises(requests.exceptions.HTTPError):
driver.try_load_file("foo")

def test_try_load_file_blob_404(self, mocker, driver):
mock_response = mocker.Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"url": "https://foo.bar"}
mocker.patch("requests.request", return_value=mock_response)

mock_blob_client = mocker.Mock()
mock_blob_client.download_blob.side_effect = ResourceNotFoundError()
mocker.patch("azure.storage.blob.BlobClient.from_blob_url", return_value=mock_blob_client)

with pytest.raises(FileNotFoundError):
driver.try_load_file("foo")

def test_try_save_files(self, mocker, driver):
mock_response = mocker.Mock()
mock_response.status_code = 200
mock_response.json.return_value = {"url": "https://foo.bar"}
mocker.patch("requests.request", return_value=mock_response)
def test_try_save_file(self, mocker, driver):
mock_url_response = mocker.Mock()
mock_url_response.status_code = 200
mock_url_response.json.return_value = {"url": "https://foo.bar"}
mocker.patch("requests.request", return_value=mock_url_response)

mock_blob_client = mocker.Mock()
mocker.patch("azure.storage.blob.BlobClient.from_blob_url", return_value=mock_blob_client)
mock_put_response = mocker.Mock()
mock_put_response.status_code = 200
mocker.patch("requests.put", return_value=mock_put_response)

response = driver.try_save_file("foo", b"value")

assert response == "buckets/1/assets/foo"

def test_try_save_file_directory(self, mocker, driver):
Expand All @@ -168,24 +153,17 @@ def test_try_save_file_directory(self, mocker, driver):
with pytest.raises(IsADirectoryError):
driver.try_save_file("foo/", b"value")

def test_try_save_file_sas_404(self, mocker, driver):
mock_response = mocker.Mock()
mock_response.json.return_value = {"url": "https://foo.bar"}
mock_response.raise_for_status.side_effect = [
requests.exceptions.HTTPError(response=mock.Mock(status_code=404)),
None,
None,
]
mocker.patch("requests.request", return_value=mock_response)

mock_blob_client = mocker.Mock()
mocker.patch("azure.storage.blob.BlobClient.from_blob_url", return_value=mock_blob_client)

response = driver.try_save_file("foo", b"value")
def test_try_save_file_asset_url_404(self, mocker, driver):
mock_create_response = mocker.Mock()
mock_create_response.raise_for_status.side_effect = requests.exceptions.HTTPError(
response=mock.Mock(status_code=404)
)
mocker.patch("requests.request", return_value=mock_create_response)

assert response == "buckets/1/assets/foo"
with pytest.raises(requests.exceptions.HTTPError):
driver.try_save_file("foo", b"value")

def test_try_save_file_sas_500(self, mocker, driver):
def test_try_save_file_asset_url_500(self, mocker, driver):
mocker.patch("requests.request", side_effect=requests.exceptions.HTTPError(response=mock.Mock(status_code=500)))

with pytest.raises(requests.exceptions.HTTPError):
Expand Down

0 comments on commit af5ca11

Please sign in to comment.