diff --git a/setup.cfg b/setup.cfg
index 1b37e0054..933510ef5 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -77,6 +77,8 @@ install_requires =
     urllib3==2.0.7
     uvloop==0.19.0
     web3==6.11.2
+    aiofiles==23.2.1
+    types-aiofiles==23.2.0.20240403
 dependency_links =
     https://github.com/aleph-im/py-libp2p/tarball/0.1.4-1-use-set#egg=libp2p
diff --git a/src/aleph/web/controllers/storage.py b/src/aleph/web/controllers/storage.py
index 341f4dc16..ac90d8ca4 100644
--- a/src/aleph/web/controllers/storage.py
+++ b/src/aleph/web/controllers/storage.py
@@ -1,11 +1,15 @@
 import base64
+import hashlib
 import logging
+import os
+import tempfile
 from decimal import Decimal
-from typing import Optional, Protocol, Union
+from typing import Optional
 
 import aio_pika
+import aiofiles
 import pydantic
-from aiohttp import web
+from aiohttp import BodyPartReader, web
 from aiohttp.web_request import FileField
 from aleph.chains.signature_verifier import SignatureVerifier
 from aleph.db.accessors.balances import get_total_balance
@@ -92,7 +96,7 @@ async def add_storage_json_controller(request: web.Request):
 
 
 async def _verify_message_signature(
-    pending_message: BasePendingMessage, signature_verifier: SignatureVerifier
+        pending_message: BasePendingMessage, signature_verifier: SignatureVerifier
 ) -> None:
     try:
         await signature_verifier.verify_signature(pending_message)
@@ -114,52 +118,102 @@ class StorageMetadata(pydantic.BaseModel):
     sync: bool
 
 
-class UploadedFile(Protocol):
-    @property
-    def size(self) -> int:
-        ...
-
-    @property
-    def content(self) -> Union[str, bytes]:
-        ...
-
-
-class MultipartUploadedFile:
-    _content: Optional[bytes]
-    size: int
-
-    def __init__(self, file_field: FileField, size: int):
+class UploadedFile:
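+    """Base class for streamed uploads.
+
+    The content is written to a temporary file on disk as it is read,
+    enforcing max_size and computing the SHA-256 hash on the fly, so the
+    whole upload never has to be held in memory.
+    """
+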
+    def __init__(self, max_size: int):
+        self.max_size = max_size
+        self.hash = ""
+        self.size = 0
+        self._hasher = hashlib.sha256()
+        self._temp_file_path: Optional[str] = None
+        self._temp_file = None
+
+    async def open_temp_file(self):
+        if not self._temp_file_path:
+            raise ValueError("File content has not been validated and read yet.")
+        self._temp_file = await aiofiles.open(self._temp_file_path, "rb")
+        return self._temp_file
+
+    async def close_temp_file(self):
+        if self._temp_file is not None:
+            await self._temp_file.close()
+            self._temp_file = None
+
+    async def cleanup(self):
+        await self.close_temp_file()
+        if self._temp_file_path and os.path.exists(self._temp_file_path):
+            os.remove(self._temp_file_path)
+            self._temp_file_path = None
+
+    async def read_and_validate(self):
+        total_read = 0
+        chunk_size = 8192
+
+        # From the aiofiles changelog:
+        #   On Python 3.12, aiofiles.tempfile.NamedTemporaryFile now accepts a
+        #   delete_on_close argument, just like the stdlib version.
+        #   On Python 3.12, aiofiles.tempfile.NamedTemporaryFile no longer
+        #   exposes a delete attribute, just like the stdlib version.
+        #
+        # We may therefore need to adjust this code for Python 3.12 at some
+        # point. It would be ideal to use aiofiles.tempfile.NamedTemporaryFile
+        # here, but it does not seem to support our current workflow.
+        temp_file = tempfile.NamedTemporaryFile("w+b", delete=False)
+        self._temp_file_path = temp_file.name
+        temp_file.close()
+
+        async with aiofiles.open(self._temp_file_path, "w+b") as f:
+            async for chunk in self._read_chunks(chunk_size):
+                total_read += len(chunk)
+                if total_read > self.max_size:
+                    raise web.HTTPRequestEntityTooLarge(
+                        reason="File size exceeds the maximum limit.",
+                        max_size=self.max_size,
+                        actual_size=total_read,
+                    )
+                self._hasher.update(chunk)  # Update the file hash while reading
+                await f.write(chunk)
+
+            self.hash = self._hasher.hexdigest()
+            self.size = total_read
+            await f.seek(0)
+
+    async def _read_chunks(self, chunk_size: int):
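+        """Yield the upload content chunk by chunk; implemented by subclasses
+        for their specific transport (multipart part or raw request body)."""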
+        raise NotImplementedError("Subclasses must implement this method")
+
+    def get_hash(self) -> str:
+        return self._hasher.hexdigest()
+
+
+class MultipartUploadedFile(UploadedFile):
+    def __init__(self, file_field: BodyPartReader, max_size: int):
+        super().__init__(max_size)
         self.file_field = file_field
-        self.size = size
-        self._content = None
 
-    @property
-    def content(self) -> bytes:
-        # Only read the stream once
-        if self._content is None:
-            self.file_field.file.seek(0)
-            self._content = self.file_field.file.read(self.size)
+    async def _read_chunks(self, chunk_size: int):
+        # BodyPartReader performs its own chunking when iterated asynchronously
+        async for chunk in self.file_field:
+            yield chunk
 
-        return self._content
 
+class RawUploadedFile(UploadedFile):
+    def __init__(self, request: web.Request, max_size: int):
+        super().__init__(max_size)
+        self.request = request
 
-class RawUploadedFile:
-    def __init__(self, content: Union[bytes, str]):
-        self.content = content
-
-    @property
-    def size(self) -> int:
-        return len(self.content)
+    async def _read_chunks(self, chunk_size: int):
+        async for chunk in self.request.content.iter_chunked(chunk_size):
+            yield chunk
 
 
 async def _check_and_add_file(
-    session: DbSession,
-    signature_verifier: SignatureVerifier,
-    storage_service: StorageService,
-    message: Optional[PendingStoreMessage],
-    file: UploadedFile,
-    grace_period: int,
+        session: DbSession,
+        signature_verifier: SignatureVerifier,
+        storage_service: StorageService,
+        message: Optional[PendingStoreMessage],
+        uploaded_file: UploadedFile,
+        grace_period: int,
 ) -> str:
+    file_hash = uploaded_file.get_hash()
     # Perform authentication and balance checks
     if message:
         await _verify_message_signature(
@@ -167,6 +221,10 @@ async def _check_and_add_file(
         )
         try:
             message_content = StoreContent.parse_raw(message.item_content)
+            if message_content.item_hash != file_hash:
+                raise web.HTTPUnprocessableEntity(
+                    reason=f"File hash does not match ({file_hash} != {message_content.item_hash})"
+                )
         except ValidationError as e:
             raise web.HTTPUnprocessableEntity(
                 reason=f"Invalid store message content: {e.json()}"
@@ -175,44 +233,40 @@ async def _check_and_add_file(
         await _verify_user_balance(
             session=session,
             address=message_content.address,
-            size=file.size,
+            size=uploaded_file.size,
         )
-
     else:
         message_content = None
 
-    # TODO: this can still reach 1 GiB in memory. We should look into streaming.
-    file_content = file.content
-    file_bytes = (
-        file_content.encode("utf-8") if isinstance(file_content, str) else file_content
-    )
-    file_hash = get_sha256(file_content)
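+    # NOTE: the validated upload is read back from the temporary file in one
+    # piece here because add_file_content_to_local_storage() expects the full
+    # content as bytes; only the network transfer itself is streamed.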
+    temp_file = await uploaded_file.open_temp_file()
+    file_content = await temp_file.read()
 
-    if message_content:
-        if message_content.item_hash != file_hash:
-            raise web.HTTPUnprocessableEntity(
-                reason=f"File hash does not match ({file_hash} != {message_content.item_hash})"
-            )
+    if isinstance(file_content, bytes):
+        file_bytes = file_content
+    elif isinstance(file_content, str):
+        file_bytes = file_content.encode("utf-8")
+    else:
+        raise web.HTTPUnprocessableEntity(
+            reason=f"Invalid file content type, got {type(file_content)}"
+        )
 
     await storage_service.add_file_content_to_local_storage(
         session=session,
         file_content=file_bytes,
-        file_hash=file_hash,
+        file_hash=file_hash
     )
+    await uploaded_file.cleanup()
 
     # For files uploaded without authenticated upload, add a grace period of 1 day.
-    if not message_content:
+    if message_content is None:
         add_grace_period_for_file(
             session=session, file_hash=file_hash, hours=grace_period
         )
-
     return file_hash
 
 
 async def _make_mq_queue(
-    request: web.Request,
-    sync: bool,
-    routing_key: Optional[str] = None,
+        request: web.Request,
+        sync: bool,
+        routing_key: Optional[str] = None,
 ) -> Optional[aio_pika.abc.AbstractQueue]:
     if not sync:
         return None
@@ -230,67 +284,76 @@ async def storage_add_file(request: web.Request):
     signature_verifier = get_signature_verifier_from_request(request)
     config = get_config_from_request(request)
     grace_period = config.storage.grace_period.value
+    metadata = None
+    uploaded_file: Optional[UploadedFile] = None
 
-    post = await request.post()
     try:
-        file_field = post["file"]
-    except KeyError:
-        raise web.HTTPUnprocessableEntity(reason="Missing 'file' in multipart form.")
-
-    if isinstance(file_field, FileField):
-        uploaded_file: UploadedFile = MultipartUploadedFile(file_field, len(file_field.file.read()))
-    else:
-        uploaded_file = RawUploadedFile(file_field)
-
-    metadata = post.get("metadata")
-
-    status_code = 200
-
-    if metadata:
-        metadata_bytes = (
-            metadata.file.read() if isinstance(metadata, FileField) else metadata
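+        # The upload arrives either as a multipart form (a "file" part plus an
+        # optional "metadata" part) or as raw bytes in the request body; both
+        # paths stream the content to a temporary file via read_and_validate().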
+        if request.content_type == "multipart/form-data":
+            reader = await request.multipart()
+            async for part in reader:
+                if part.name == "file":
+                    uploaded_file = MultipartUploadedFile(part, MAX_FILE_SIZE)
+                    await uploaded_file.read_and_validate()
+                elif part.name == "metadata":
+                    metadata = await part.read(decode=True)
+        else:
+            uploaded_file = RawUploadedFile(
+                request=request, max_size=MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE
+            )
+            await uploaded_file.read_and_validate()
+
+        if uploaded_file is None:
+            raise web.HTTPBadRequest(
+                reason="File must be sent as multipart form data or a raw upload"
+            )
+
+        max_upload_size = (
+            MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE if not metadata else MAX_FILE_SIZE
         )
-        try:
-            storage_metadata = StorageMetadata.parse_raw(metadata_bytes)
-        except ValidationError as e:
-            raise web.HTTPUnprocessableEntity(
-                reason=f"Could not decode metadata: {e.json()}"
+        if uploaded_file.size > max_upload_size:
+            raise web.HTTPRequestEntityTooLarge(
+                actual_size=uploaded_file.size, max_size=max_upload_size
             )
-
-        message = storage_metadata.message
-        sync = storage_metadata.sync
-        max_upload_size = MAX_UPLOAD_FILE_SIZE
-
-    else:
-        # User did not provide a message in the `metadata` field
-        message = None
-        sync = False
-        max_upload_size = MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE
+        uploaded_file.max_size = max_upload_size
 
-    if uploaded_file.size > max_upload_size:
-        raise web.HTTPRequestEntityTooLarge(
-            actual_size=uploaded_file.size, max_size=MAX_UPLOAD_FILE_SIZE
-        )
+        status_code = 200
 
-    with session_factory() as session:
-        file_hash = await _check_and_add_file(
-            session=session,
-            signature_verifier=signature_verifier,
-            storage_service=storage_service,
-            message=message,
-            file=uploaded_file,
-            grace_period=grace_period,
-        )
-        session.commit()
+        if metadata:
+            metadata_bytes = (
+                metadata.file.read() if isinstance(metadata, FileField) else metadata
+            )
+            try:
+                storage_metadata = StorageMetadata.parse_raw(metadata_bytes)
+            except ValidationError as e:
+                raise web.HTTPUnprocessableEntity(
+                    reason=f"Could not decode metadata: {e.json()}"
+                )
+
+            message = storage_metadata.message
+            sync = storage_metadata.sync
+        else:
+            message = None
+            sync = False
+
+        with session_factory() as session:
+            file_hash = await _check_and_add_file(
+                session=session,
+                signature_verifier=signature_verifier,
+                storage_service=storage_service,
+                message=message,
+                uploaded_file=uploaded_file,
+                grace_period=grace_period,
+            )
+            session.commit()
+        if message:
+            broadcast_status = await broadcast_and_process_message(
+                pending_message=message, sync=sync, request=request, logger=logger
+            )
+            status_code = broadcast_status_to_http_status(broadcast_status)
 
-    if message:
-        broadcast_status = await broadcast_and_process_message(
-            pending_message=message, sync=sync, request=request, logger=logger
-        )
-        status_code = broadcast_status_to_http_status(broadcast_status)
+        output = {"status": "success", "hash": file_hash}
+        return web.json_response(data=output, status=status_code)
 
-    output = {"status": "success", "hash": file_hash}
-    return web.json_response(data=output, status=status_code)
+    finally:
+        if uploaded_file is not None:
+            await uploaded_file.cleanup()
 
 
 def assert_file_is_downloadable(session: DbSession, file_hash: str) -> None:
diff --git a/tests/api/test_storage.py b/tests/api/test_storage.py
index 382eb2ca2..7fc383dd3 100644
--- a/tests/api/test_storage.py
+++ b/tests/api/test_storage.py
@@ -92,6 +92,41 @@ async def api_client(ccn_test_aiohttp_app, mocker, aiohttp_client):
     return client
 
 
+async def add_file_raw_upload(
+    api_client,
+    session_factory: DbSessionFactory,
+    uri: str,
+    file_content: bytes,
+    expected_file_hash: str,
+):
+    # Send the file content as raw bytes in the request body
+    headers = {"Content-Type": "application/octet-stream"}
+    post_response = await api_client.post(uri, data=file_content, headers=headers)
+    response_text = await post_response.text()
+    assert post_response.status == 200, response_text
+    post_response_json = await post_response.json()
+    assert post_response_json["status"] == "success"
+    file_hash = post_response_json["hash"]
+    assert file_hash == expected_file_hash
+
+    # Assert that the file is downloadable
+    get_file_response = await api_client.get(f"{GET_STORAGE_RAW_URI}/{file_hash}")
+    assert get_file_response.status == 200, await get_file_response.text()
+    response_data = await get_file_response.read()
+
+    # Check that the file appears in the DB
+    with session_factory() as session:
+        file = get_file(session=session, file_hash=file_hash)
+        assert file is not None
+        assert file.hash == file_hash
+        assert file.type == FileType.FILE
+        assert file.size == len(file_content)
+
+    assert response_data == file_content
+
+
 async def add_file(
     api_client,
     session_factory: DbSessionFactory,
@@ -220,6 +255,16 @@ async def test_storage_add_file(api_client, session_factory: DbSessionFactory):
         expected_file_hash=EXPECTED_FILE_SHA256,
     )
 
+
+@pytest.mark.asyncio
+async def test_storage_add_file_raw_upload(api_client, session_factory: DbSessionFactory):
+    await add_file_raw_upload(
+        api_client,
+        session_factory,
+        uri=STORAGE_ADD_FILE_URI,
+        file_content=FILE_CONTENT,
+        expected_file_hash=EXPECTED_FILE_SHA256,
+    )
+
 
 @pytest.mark.parametrize(
     "file_content, expected_hash, size, error_code, balance",
@@ -371,3 +416,4 @@ async def test_ipfs_add_json(api_client, session_factory: DbSessionFactory):
     # creating a second fixture.
     expected_file_hash=ItemHash(EXPECTED_FILE_CID),
     )
+