Fix: Upload Endpoints #565

Merged 18 commits on Jun 20, 2024
Commits
cefcc23
Refactor: Multipart MultipartUploadedFile will read with chunk to a…
1yam Apr 25, 2024
885c59c
Feature: Upload storage endpoint using tempfile to reduce memory usage
1yam May 2, 2024
e62da4b
Refactor: using request.read() when content type is not multipart/form…
1yam May 15, 2024
ff83b6d
Fix: mypy issue
1yam May 15, 2024
68601a0
Refactor: Upload endpoints using multipart instead of post and raw up…
1yam May 28, 2024
372c839
Feature: Unit test for raw upload
1yam May 28, 2024
b9e39b6
Fix: mypy error
1yam May 28, 2024
c87263e
Fix: last mypy error
1yam May 28, 2024
e77361a
Refactor: Uploaded_file & storage_upload
1yam May 30, 2024
7437506
Refactor: remove the context manager from UploadedFile class to just …
1yam May 30, 2024
f8d4b3f
refactor(storage): avoid using private variables when not needed
Psycojoker Jun 5, 2024
beda20b
feat(storage): make UploadedFile.cleanup callable multiple times
Psycojoker Jun 5, 2024
f7fc431
doc(storage): add comment regarding NamedTemporaryFile API change in …
Psycojoker Jun 5, 2024
3d49f4c
refactor(storage): rename inner function variable for better readability
Psycojoker Jun 5, 2024
1049a66
feat(storage): make web.HTTPUnprocessableEntity exception message mor…
Psycojoker Jun 5, 2024
e69d550
refactor(storage): make if more pythonic
Psycojoker Jun 5, 2024
110c4ee
fix(storage): ensure that uploaded temporary file is **always** cleaned up
Psycojoker Jun 5, 2024
526de08
fix(storage): function argument name has changed
Psycojoker Jun 5, 2024
2 changes: 2 additions & 0 deletions setup.cfg
@@ -77,6 +77,8 @@ install_requires =
urllib3==2.0.7
uvloop==0.19.0
web3==6.11.2
aiofiles==23.2.1
types-aiofiles==23.2.0.20240403

dependency_links =
https://github.com/aleph-im/py-libp2p/tarball/0.1.4-1-use-set#egg=libp2p
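
The two new pins above bring in aiofiles (plus its type stubs for mypy), which the reworked upload path uses to stream request bodies to disk instead of buffering them in memory. A minimal sketch of that write pattern, with illustrative names and chunk source (not taken from the diff):

```python
# Minimal sketch of the chunked aiofiles write pattern; `stream_to_disk`,
# the path and the chunk source are illustrative, not from this PR.
import hashlib

import aiofiles


async def stream_to_disk(chunks, path: str) -> str:
    """Write an async iterable of byte chunks to `path`; return its sha256."""
    hasher = hashlib.sha256()
    async with aiofiles.open(path, "wb") as f:
        async for chunk in chunks:
            hasher.update(chunk)  # hash while writing, as the PR's code does
            await f.write(chunk)
    return hasher.hexdigest()
```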
285 changes: 174 additions & 111 deletions src/aleph/web/controllers/storage.py
@@ -1,11 +1,15 @@
import base64
import hashlib
import logging
import os
import tempfile
from decimal import Decimal
from typing import Optional, Protocol, Union
from typing import Optional

import aio_pika
import aiofiles
import pydantic
from aiohttp import web
from aiohttp import BodyPartReader, web
from aiohttp.web_request import FileField
from aleph.chains.signature_verifier import SignatureVerifier
from aleph.db.accessors.balances import get_total_balance
@@ -92,7 +96,7 @@ async def add_storage_json_controller(request: web.Request):


async def _verify_message_signature(
pending_message: BasePendingMessage, signature_verifier: SignatureVerifier
pending_message: BasePendingMessage, signature_verifier: SignatureVerifier
) -> None:
try:
await signature_verifier.verify_signature(pending_message)
@@ -114,59 +118,113 @@ class StorageMetadata(pydantic.BaseModel):
sync: bool


class UploadedFile(Protocol):
@property
def size(self) -> int:
...

@property
def content(self) -> Union[str, bytes]:
...


class MultipartUploadedFile:
_content: Optional[bytes]
size: int

def __init__(self, file_field: FileField, size: int):
class UploadedFile:
def __init__(self, max_size: int):
self.max_size = max_size
self.hash = ""
self.size = 0
self._hasher = hashlib.sha256()
self._temp_file_path = None
self._temp_file = None

async def open_temp_file(self):
if not self._temp_file_path:
raise ValueError("File content has not been validated and read yet.")
self._temp_file = await aiofiles.open(self._temp_file_path, 'rb')
return self._temp_file

async def close_temp_file(self):
if self._temp_file is not None:
await self._temp_file.close()
self._temp_file = None

async def cleanup(self):
await self.close_temp_file()
if self._temp_file_path and os.path.exists(self._temp_file_path):
os.remove(self._temp_file_path)
self._temp_file_path = None

async def read_and_validate(self):
total_read = 0
chunk_size = 8192

# From aiofiles changelog:
# On Python 3.12, aiofiles.tempfile.NamedTemporaryFile now accepts a
# delete_on_close argument, just like the stdlib version.
# On Python 3.12, aiofiles.tempfile.NamedTemporaryFile no longer
# exposes a delete attribute, just like the stdlib version.
#
# so we might need to modify this code for python 3.12 at some point

# it would be ideal to use aiofiles.tempfile.NamedTemporaryFile but it
# doesn't seem to be able to support our current workflow
temp_file = tempfile.NamedTemporaryFile('w+b', delete=False)
self._temp_file_path = temp_file.name
temp_file.close()

async with aiofiles.open(self._temp_file_path, 'w+b') as f:
async for chunk in self._read_chunks(chunk_size):
total_read += len(chunk)
if total_read > self.max_size:
raise web.HTTPRequestEntityTooLarge(
reason="File size exceeds the maximum limit.",
max_size=self.max_size,
actual_size=total_read,
)
self._hasher.update(chunk) # Update file hash while reading the file
await f.write(chunk)

self.hash = self._hasher.hexdigest()
self.size = total_read
await f.seek(0)

async def _read_chunks(self, chunk_size):
raise NotImplementedError("Subclasses must implement this method")

def get_hash(self) -> str:
return self._hasher.hexdigest()
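
Taken together, the base class above implements a stream-validate-store lifecycle: `read_and_validate` pulls chunks from a subclass-provided source into a named temp file while hashing and size-checking, `open_temp_file` reopens that file for consumers, and `cleanup` closes and deletes it. A hedged usage sketch (the handler wiring and the size limit are illustrative, not from this diff):

```python
# Hedged sketch of the UploadedFile lifecycle; `request` is assumed to be an
# aiohttp web.Request, and the 10 MiB limit is illustrative.
async def handle_upload(request):
    uploaded = RawUploadedFile(request=request, max_size=10 * 1024 * 1024)
    try:
        await uploaded.read_and_validate()  # stream to temp file, hash, enforce max_size
        temp_file = await uploaded.open_temp_file()
        content = await temp_file.read()    # read the validated bytes back
        return uploaded.get_hash(), content
    finally:
        await uploaded.cleanup()            # always close and delete the temp file
```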


class MultipartUploadedFile(UploadedFile):
def __init__(self, file_field: BodyPartReader, max_size: int):
super().__init__(max_size)
self.file_field = file_field
self.size = size
self._content = None

@property
def content(self) -> bytes:
# Only read the stream once
if self._content is None:
self.file_field.file.seek(0)
self._content = self.file_field.file.read(self.size)
async def _read_chunks(self, chunk_size):
async for chunk in self.file_field.__aiter__():
yield chunk

return self._content

class RawUploadedFile(UploadedFile):
def __init__(self, request: web.Request, max_size: int):
super().__init__(max_size)
self.request = request

class RawUploadedFile:
def __init__(self, content: Union[bytes, str]):
self.content = content

@property
def size(self) -> int:
return len(self.content)
async def _read_chunks(self, chunk_size):
async for chunk in self.request.content.iter_chunked(chunk_size):
yield chunk
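
The two subclasses differ only in where the chunks come from: the multipart body part iterator versus the raw request stream. Any new source only needs to override `_read_chunks`; a hypothetical in-memory variant (not part of this PR, e.g. for unit tests) could look like:

```python
# Hypothetical subclass, not in this PR: feeds in-memory bytes through the
# same hash/size-check/temp-file machinery, which can be handy in tests.
class BytesUploadedFile(UploadedFile):
    def __init__(self, data: bytes, max_size: int):
        super().__init__(max_size)
        self.data = data

    async def _read_chunks(self, chunk_size):
        for start in range(0, len(self.data), chunk_size):
            yield self.data[start:start + chunk_size]
```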


async def _check_and_add_file(
session: DbSession,
signature_verifier: SignatureVerifier,
storage_service: StorageService,
message: Optional[PendingStoreMessage],
file: UploadedFile,
grace_period: int,
session: DbSession,
signature_verifier: SignatureVerifier,
storage_service: StorageService,
message: Optional[PendingStoreMessage],
uploaded_file: UploadedFile,
grace_period: int,
) -> str:
file_hash = uploaded_file.get_hash()
# Perform authentication and balance checks
if message:
await _verify_message_signature(
pending_message=message, signature_verifier=signature_verifier
)
try:
message_content = StoreContent.parse_raw(message.item_content)
if message_content.item_hash != file_hash:
raise web.HTTPUnprocessableEntity(
reason=f"File hash does not match ({file_hash} != {message_content.item_hash})"
)
except ValidationError as e:
raise web.HTTPUnprocessableEntity(
reason=f"Invalid store message content: {e.json()}"
@@ -175,44 +233,40 @@ async def _check_and_add_file(
await _verify_user_balance(
session=session,
address=message_content.address,
size=file.size,
size=uploaded_file.size,
)

else:
message_content = None

# TODO: this can still reach 1 GiB in memory. We should look into streaming.
file_content = file.content
file_bytes = (
file_content.encode("utf-8") if isinstance(file_content, str) else file_content
)
file_hash = get_sha256(file_content)
temp_file = await uploaded_file.open_temp_file()
file_content = await temp_file.read()

if message_content:
if message_content.item_hash != file_hash:
raise web.HTTPUnprocessableEntity(
reason=f"File hash does not match ({file_hash} != {message_content.item_hash})"
)
if isinstance(file_content, bytes):
file_bytes = file_content
elif isinstance(file_content, str):
file_bytes = file_content.encode("utf-8")
else:
raise web.HTTPUnprocessableEntity(reason=f"Invalid file content type, got {type(file_content)}")

await storage_service.add_file_content_to_local_storage(
session=session,
file_content=file_bytes,
file_hash=file_hash,
file_hash=file_hash
)
await uploaded_file.cleanup()

# For files uploaded without authenticated upload, add a grace period of 1 day.
if not message_content:
if message_content is None:
add_grace_period_for_file(
session=session, file_hash=file_hash, hours=grace_period
)

return file_hash


async def _make_mq_queue(
request: web.Request,
sync: bool,
routing_key: Optional[str] = None,
request: web.Request,
sync: bool,
routing_key: Optional[str] = None,
) -> Optional[aio_pika.abc.AbstractQueue]:
if not sync:
return None
@@ -230,67 +284,76 @@ async def storage_add_file(request: web.Request):
signature_verifier = get_signature_verifier_from_request(request)
config = get_config_from_request(request)
grace_period = config.storage.grace_period.value
metadata = None
uploaded_file: Optional[UploadedFile] = None

post = await request.post()
try:
file_field = post["file"]
except KeyError:
raise web.HTTPUnprocessableEntity(reason="Missing 'file' in multipart form.")

if isinstance(file_field, FileField):
uploaded_file: UploadedFile = MultipartUploadedFile(file_field, len(file_field.file.read()))
else:
uploaded_file = RawUploadedFile(file_field)

metadata = post.get("metadata")

status_code = 200

if metadata:
metadata_bytes = (
metadata.file.read() if isinstance(metadata, FileField) else metadata
if request.content_type == "multipart/form-data":
reader = await request.multipart()
async for part in reader:
if part.name == 'file':
uploaded_file = MultipartUploadedFile(part, MAX_FILE_SIZE)
await uploaded_file.read_and_validate()
elif part.name == 'metadata':
metadata = await part.read(decode=True)
else:
uploaded_file = RawUploadedFile(request=request, max_size=MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE)
await uploaded_file.read_and_validate()

if uploaded_file is None:
raise web.HTTPBadRequest(reason="File should be sent as FormData or Raw Upload")

max_upload_size = (
MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE if not metadata else MAX_FILE_SIZE
)
try:
storage_metadata = StorageMetadata.parse_raw(metadata_bytes)
except ValidationError as e:
raise web.HTTPUnprocessableEntity(
reason=f"Could not decode metadata: {e.json()}"
if uploaded_file.size > max_upload_size:
raise web.HTTPRequestEntityTooLarge(
actual_size=uploaded_file.size, max_size=max_upload_size
)

message = storage_metadata.message
sync = storage_metadata.sync
max_upload_size = MAX_UPLOAD_FILE_SIZE

else:
# User did not provide a message in the `metadata` field
message = None
sync = False
max_upload_size = MAX_UNAUTHENTICATED_UPLOAD_FILE_SIZE
uploaded_file.max_size = max_upload_size

if uploaded_file.size > max_upload_size:
raise web.HTTPRequestEntityTooLarge(
actual_size=uploaded_file.size, max_size=MAX_UPLOAD_FILE_SIZE
)
status_code = 200

with session_factory() as session:
file_hash = await _check_and_add_file(
session=session,
signature_verifier=signature_verifier,
storage_service=storage_service,
message=message,
file=uploaded_file,
grace_period=grace_period,
)
session.commit()
if metadata:
metadata_bytes = (
metadata.file.read() if isinstance(metadata, FileField) else metadata
)
try:
storage_metadata = StorageMetadata.parse_raw(metadata_bytes)
except ValidationError as e:
raise web.HTTPUnprocessableEntity(
reason=f"Could not decode metadata: {e.json()}"
)

message = storage_metadata.message
sync = storage_metadata.sync
else:
message = None
sync = False

with session_factory() as session:
file_hash = await _check_and_add_file(
session=session,
signature_verifier=signature_verifier,
storage_service=storage_service,
message=message,
uploaded_file=uploaded_file,
grace_period=grace_period,
)
session.commit()
if message:
broadcast_status = await broadcast_and_process_message(
pending_message=message, sync=sync, request=request, logger=logger
)
status_code = broadcast_status_to_http_status(broadcast_status)

if message:
broadcast_status = await broadcast_and_process_message(
pending_message=message, sync=sync, request=request, logger=logger
)
status_code = broadcast_status_to_http_status(broadcast_status)
output = {"status": "success", "hash": file_hash}
return web.json_response(data=output, status=status_code)

output = {"status": "success", "hash": file_hash}
return web.json_response(data=output, status=status_code)
finally:
if uploaded_file is not None:
await uploaded_file.cleanup()
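
With the `try`/`finally` above, the temporary file is removed even when validation or broadcasting raises. From the client side, the handler now accepts both upload modes; a hedged sketch using aiohttp (the endpoint URL and port are assumptions, not taken from this diff):

```python
# Hedged client sketch; the endpoint URL below is an assumption, not from the diff.
import aiohttp


async def upload_examples(url: str = "http://localhost:4024/api/v0/storage/add_file"):
    async with aiohttp.ClientSession() as session:
        # Multipart mode: a `file` part, plus an optional `metadata` part.
        form = aiohttp.FormData()
        form.add_field("file", b"hello world", filename="hello.txt")
        async with session.post(url, data=form) as resp:
            print(await resp.json())

        # Raw mode: any non-multipart body is streamed from request.content.
        async with session.post(url, data=b"hello world") as resp:
            print(await resp.json())
```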


def assert_file_is_downloadable(session: DbSession, file_hash: str) -> None: