Skip to content

Commit

Permalink
Factor storage clients into separate modules. Closes Chainlit#1355. (C…
Browse files Browse the repository at this point in the history
…hainlit#1363)

* Refactor storage_clients.py from a script to a module inside chainlit.data
* Add pytest on S3StorageClient + add moto in tests dependencies
* Moved BaseStorageClient from from chainlit.data.base.py to chainlit.data.storage_clients.base.py

Co-authored-by: jl1andricca <[email protected]>
  • Loading branch information
ndricca and jl1andricca authored Oct 14, 2024
1 parent 189b3e4 commit e1e206b
Show file tree
Hide file tree
Showing 13 changed files with 287 additions and 53 deletions.
14 changes: 0 additions & 14 deletions backend/chainlit/data/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,17 +105,3 @@ async def update_thread(
@abstractmethod
async def build_debug_url(self) -> str:
pass


class BaseStorageClient(ABC):
"""Base class for non-text data persistence like Azure Data Lake, S3, Google Storage, etc."""

@abstractmethod
async def upload_file(
self,
object_key: str,
data: Union[bytes, str],
mime: str = "application/octet-stream",
overwrite: bool = True,
) -> Dict[str, Any]:
pass
3 changes: 2 additions & 1 deletion backend/chainlit/data/dynamodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
import boto3 # type: ignore
from boto3.dynamodb.types import TypeDeserializer, TypeSerializer
from chainlit.context import context
from chainlit.data.base import BaseDataLayer, BaseStorageClient
from chainlit.data.base import BaseDataLayer
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.data.utils import queue_until_user_message
from chainlit.element import ElementDict
from chainlit.logger import logger
Expand Down
3 changes: 2 additions & 1 deletion backend/chainlit/data/sql_alchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
import aiofiles
import aiohttp

from chainlit.data.base import BaseDataLayer, BaseStorageClient
from chainlit.data.base import BaseDataLayer
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.data.utils import queue_until_user_message
from chainlit.element import ElementDict
from chainlit.logger import logger
Expand Down
Empty file.
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

import boto3 # type: ignore
from azure.storage.filedatalake import (
ContentSettings,
DataLakeFileClient,
DataLakeServiceClient,
FileSystemClient,
)
from chainlit.data.base import BaseStorageClient
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.logger import logger

if TYPE_CHECKING:
Expand Down Expand Up @@ -79,34 +78,3 @@ async def upload_file(
except Exception as e:
logger.warn(f"AzureStorageClient, upload_file error: {e}")
return {}


class S3StorageClient(BaseStorageClient):
"""
Class to enable Amazon S3 storage provider
"""

def __init__(self, bucket: str):
try:
self.bucket = bucket
self.client = boto3.client("s3")
logger.info("S3StorageClient initialized")
except Exception as e:
logger.warn(f"S3StorageClient initialization error: {e}")

async def upload_file(
self,
object_key: str,
data: Union[bytes, str],
mime: str = "application/octet-stream",
overwrite: bool = True,
) -> Dict[str, Any]:
try:
self.client.put_object(
Bucket=self.bucket, Key=object_key, Body=data, ContentType=mime
)
url = f"https://{self.bucket}.s3.amazonaws.com/{object_key}"
return {"object_key": object_key, "url": url}
except Exception as e:
logger.warn(f"S3StorageClient, upload_file error: {e}")
return {}
16 changes: 16 additions & 0 deletions backend/chainlit/data/storage_clients/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Union


class BaseStorageClient(ABC):
"""Base class for non-text data persistence like Azure Data Lake, S3, Google Storage, etc."""

@abstractmethod
async def upload_file(
self,
object_key: str,
data: Union[bytes, str],
mime: str = "application/octet-stream",
overwrite: bool = True,
) -> Dict[str, Any]:
pass
36 changes: 36 additions & 0 deletions backend/chainlit/data/storage_clients/s3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from typing import Any, Dict, Union

import boto3 # type: ignore
from chainlit.data.storage_clients.base import BaseStorageClient
from chainlit.logger import logger


class S3StorageClient(BaseStorageClient):
"""
Class to enable Amazon S3 storage provider
"""

def __init__(self, bucket: str):
try:
self.bucket = bucket
self.client = boto3.client("s3")
logger.info("S3StorageClient initialized")
except Exception as e:
logger.warn(f"S3StorageClient initialization error: {e}")

async def upload_file(
self,
object_key: str,
data: Union[bytes, str],
mime: str = "application/octet-stream",
overwrite: bool = True,
) -> Dict[str, Any]:
try:
self.client.put_object(
Bucket=self.bucket, Key=object_key, Body=data, ContentType=mime
)
url = f"https://{self.bucket}.s3.amazonaws.com/{object_key}"
return {"object_key": object_key, "url": url}
except Exception as e:
logger.warn(f"S3StorageClient, upload_file error: {e}")
return {}
Loading

0 comments on commit e1e206b

Please sign in to comment.