-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
implement multipart upload via on-disk tempfiles
- Loading branch information
1 parent
47cfcdd
commit 2c3972c
Showing
6 changed files
with
276 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
"""Functions for uploading a non-seekable stream to S3.""" | ||
|
||
from __future__ import annotations | ||
|
||
import os | ||
from tempfile import TemporaryFile | ||
from typing import TYPE_CHECKING | ||
|
||
if TYPE_CHECKING: | ||
from typing import AnyStr | ||
from typing import IO | ||
from typing import Iterator | ||
|
||
from mypy_boto3_s3.client import S3Client | ||
from mypy_boto3_s3.type_defs import CompletedPartTypeDef | ||
|
||
_COPY_BUFFER_SIZE = 2**20 | ||
|
||
|
||
def _copy(*, input_file: IO[bytes], output_file: IO[bytes], count: int) -> int: | ||
# notes for fast copy: | ||
# - buffered IO classes have seek position out of sync with underlying file | ||
# descriptor, mitigate this with SEEK_END | ||
# - sendfile and splice. which should be used first? | ||
# - use readinto and reuse buffers for slow mode | ||
written = 0 | ||
eof = False | ||
while written < count and not eof: | ||
# This may result in one or multiple underlying reads | ||
buf = input_file.read(min(count - written, _COPY_BUFFER_SIZE)) | ||
if not buf: | ||
eof = True | ||
break | ||
# https://docs.python.org/3/library/io.html#io.RawIOBase.write | ||
# indicates this loop is needed | ||
offset = 0 | ||
while offset < len(buf): | ||
offset += output_file.write(buf[offset:]) | ||
written += len(buf) | ||
return written | ||
|
||
|
||
def _iter_parts_via_tempfile(stream: IO[bytes], part_size: int) -> Iterator[IO[bytes]]:
    """Split *stream* into temporary files of at most *part_size* bytes.

    Each yielded file is seekable, positioned at offset zero, and only
    valid until the next iteration step (it is closed and deleted by the
    enclosing context manager afterwards). Iteration stops after the first
    part shorter than part_size; a zero-length trailing part is never
    yielded.
    """
    last_part = False
    while not last_part:
        with TemporaryFile() as part_file:
            copied = _copy(input_file=stream, output_file=part_file, count=part_size)
            # A short part means the source stream is exhausted.
            last_part = copied < part_size
            if copied:
                part_file.seek(0, os.SEEK_SET)
                yield part_file
|
||
|
||
def _stream_len(stream: IO[AnyStr]) -> int: | ||
cur = stream.tell() | ||
end = stream.seek(0, os.SEEK_END) | ||
stream.seek(cur, os.SEEK_SET) | ||
return end | ||
|
||
|
||
def upload_non_seekable_stream_via_tempfile(
    *, stream: IO[bytes], client: S3Client, bucket: str, key: str, part_size: int
) -> None:
    """Upload a non-seekable stream to S3.

    The stream is buffered to on-disk temporary files in chunks of up to
    part_size bytes. If the stream fits in a single chunk smaller than
    part_size, the object is uploaded with a plain put_object call;
    otherwise a multipart upload is used.

    If any error is raised after a multipart upload has been started, this
    function attempts to cancel it with abort_multipart_upload() before
    re-raising.

    Args:
        stream: A stream to upload. The stream may be seekable, but this
            function is designed for the non-seekable case.
        client: The S3 client object.
        bucket: The name of the S3 bucket.
        key: The key of the S3 object in the bucket.
        part_size: The maximum size of a single part.
    """
    # Heuristic: when the first part is exactly part_size bytes, assume more
    # parts will follow and start a multipart upload. This is suboptimal only
    # in the rare case that the stream is exactly one part long. The
    # alternative — probing one extra byte after the first part and stitching
    # it onto the next — leads to frequently-unaligned 1-byte reads and cache
    # thrashing; the optimal scheme (full first part, 1 probe byte, second
    # part minus one byte, then normal parts) adds a lot of complexity for a
    # very rare gain.
    upload_id: str | None = None
    completed_parts: list[CompletedPartTypeDef] = []
    try:
        part_iter = _iter_parts_via_tempfile(stream, part_size)
        for part_index, part_file in enumerate(part_iter):
            if upload_id is None and _stream_len(part_file) == part_size:
                create_response = client.create_multipart_upload(
                    Bucket=bucket, Key=key
                )
                upload_id = create_response["UploadId"]
            if upload_id is None:
                # Entire stream fit in one short part: single-shot upload.
                client.put_object(Bucket=bucket, Key=key, Body=part_file)
            else:
                part_number = part_index + 1  # S3 part numbers are 1-based
                part_response = client.upload_part(
                    Bucket=bucket,
                    Key=key,
                    PartNumber=part_number,
                    UploadId=upload_id,
                    Body=part_file,
                )
                completed_parts.append(
                    {"ETag": part_response["ETag"], "PartNumber": part_number}
                )
        if upload_id is not None:
            client.complete_multipart_upload(
                Bucket=bucket,
                Key=key,
                UploadId=upload_id,
                MultipartUpload={"Parts": completed_parts},
            )
    except Exception:
        # Best-effort cleanup so the bucket isn't left with a dangling
        # (billable) multipart upload.
        if upload_id is not None:
            client.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id)
        raise
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
from __future__ import annotations | ||
|
||
import os | ||
from tempfile import TemporaryFile | ||
from threading import Thread | ||
from typing import cast | ||
from typing import TYPE_CHECKING | ||
|
||
import pytest | ||
|
||
if TYPE_CHECKING: | ||
from typing import IO | ||
from typing import Iterator | ||
|
||
|
||
@pytest.fixture(params=[5 * 2**20, 5 * 2**20 + 512, 10 * 2**20])
def data_size(request: pytest.FixtureRequest) -> int:
    # Payload sizes: exactly one part, just over one part, and two parts
    # (relative to the 5 MiB minimum S3 part size).
    size: int = request.param
    return size
|
||
|
||
@pytest.fixture(params=[5 * 2**20, 5 * 2**20 + 512, 10 * 2**20])
def part_size(request: pytest.FixtureRequest) -> int:
    # Part sizes mirror the data sizes so tests cover equal, smaller, and
    # larger part/data combinations.
    size: int = request.param
    return size
|
||
|
||
@pytest.fixture()
def stream_data(data_size: int) -> bytes:
    """Random payload of the requested size."""
    return os.urandom(data_size)
|
||
|
||
@pytest.fixture(params=[-1, 0], ids=["buffered", "unbuffered"])
def buffering(request: pytest.FixtureRequest) -> int:
    # Passed straight through to open()/TemporaryFile(): -1 is the default
    # buffered mode, 0 is unbuffered.
    mode: int = request.param
    return mode
|
||
|
||
@pytest.fixture(params=[False, True], ids=["nonseekable", "seekable"])
def seekable(request: pytest.FixtureRequest) -> bool:
    # Whether the `stream` fixture should produce a seekable file or a pipe.
    flag: bool = request.param
    return flag
|
||
|
||
@pytest.fixture()
def stream(buffering: int, seekable: bool, stream_data: bytes) -> Iterator[IO[bytes]]:  # noqa: FBT001
    """Yield a readable binary stream containing stream_data.

    The seekable variant is a temporary file pre-filled with the payload;
    the non-seekable variant is the read end of an os.pipe() fed by a
    background thread.
    """
    if not seekable:
        read_fd, write_fd = os.pipe()

        def feed(fd: int) -> None:
            os.write(fd, stream_data)
            os.close(fd)

        writer = Thread(target=feed, args=(write_fd,))
        writer.start()
        with open(read_fd, mode="rb", buffering=buffering) as pipe_reader:  # noqa: PTH123
            yield pipe_reader
        writer.join()
    else:
        with TemporaryFile(buffering=buffering) as tmp:
            tmp.write(stream_data)
            tmp.seek(0, os.SEEK_SET)
            yield tmp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
|
||
from btrfs2s3.stream_uploader import _iter_parts_via_tempfile | ||
|
||
if TYPE_CHECKING: | ||
from typing import IO | ||
|
||
|
||
def test_iter_parts_via_tempfile(
    stream: IO[bytes], stream_data: bytes, part_size: int
) -> None:
    """Parts come back seekable, rewound, and concatenate to the input."""
    expected_parts = [
        stream_data[start : start + part_size]
        for start in range(0, len(stream_data), part_size)
    ]

    collected = []
    for part_file in _iter_parts_via_tempfile(stream, part_size):
        assert part_file.seekable()
        assert part_file.tell() == 0
        collected.append(part_file.read())

    assert collected == expected_parts
49 changes: 49 additions & 0 deletions
49
tests/stream_uploader/upload_non_seekable_stream_via_tempfile_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
from __future__ import annotations | ||
|
||
from typing import TYPE_CHECKING | ||
from unittest.mock import Mock | ||
|
||
from btrfs2s3.stream_uploader import upload_non_seekable_stream_via_tempfile | ||
import pytest | ||
|
||
if TYPE_CHECKING: | ||
from typing import IO | ||
|
||
from mypy_boto3_s3.client import S3Client | ||
|
||
|
||
def test_file_is_uploaded(
    s3: S3Client, stream: IO[bytes], bucket: str, part_size: int, stream_data: bytes
) -> None:
    """Round-trip: data pushed through the uploader comes back via get_object."""
    key = "test-key"

    upload_non_seekable_stream_via_tempfile(
        client=s3, stream=stream, bucket=bucket, key=key, part_size=part_size
    )

    stored = s3.get_object(Bucket=bucket, Key=key)["Body"]
    assert stored.read() == stream_data
|
||
|
||
class FakeError(Exception):
    """Sentinel exception raised by mocked S3 calls to exercise cleanup paths."""
|
||
|
||
def test_multipart_upload_gets_cancelled_on_error(
    s3: S3Client, stream: IO[bytes], bucket: str, part_size: int
) -> None:
    """A failed upload must not leave a dangling multipart upload behind."""
    key = "test-key"

    # Wrap the real client so create/abort still work, but every data
    # transfer call fails.
    failing_client = Mock(wraps=s3)
    failing_client.put_object.side_effect = FakeError()
    failing_client.upload_part.side_effect = FakeError()

    with pytest.raises(FakeError):
        upload_non_seekable_stream_via_tempfile(
            client=failing_client,
            stream=stream,
            bucket=bucket,
            key=key,
            part_size=part_size,
        )

    assert s3.list_multipart_uploads(Bucket=bucket).get("Uploads", []) == []