implement multipart upload via on-disk tempfiles
sbrudenell committed Aug 14, 2024
1 parent 47cfcdd commit 2c3972c
Showing 6 changed files with 276 additions and 1 deletion.
src/btrfs2s3/action.py: 17 additions & 1 deletion
@@ -14,6 +14,7 @@

from btrfs2s3._internal.util import NULL_UUID
from btrfs2s3._internal.util import SubvolumeFlags
+from btrfs2s3.stream_uploader import upload_non_seekable_stream_via_tempfile
from btrfs2s3.thunk import Thunk
from btrfs2s3.thunk import ThunkArg

@@ -26,6 +27,8 @@

_LOG = logging.getLogger(__name__)

+DEFAULT_PART_SIZE = 5 * 2**30


def create_snapshot(*, source: Path, path: Path) -> None:
"""Create a read-only snapshot of a subvolume.
@@ -94,12 +97,16 @@ def create_backup( # noqa: PLR0913
    send_parent: Path | None,
    key: str,
    pipe_through: Sequence[Sequence[str]] = (),
+    part_size: int = DEFAULT_PART_SIZE,
) -> None:
    """Stores a btrfs archive in S3.

    This will spawn "btrfs -q send" as a subprocess, as there is currently no way
    to create a btrfs-send stream via pure python.

+    This will temporarily store the "btrfs send" stream on disk. A maximum of
+    part_size bytes will be stored at a time.

    Args:
        s3: An S3 client.
        bucket: The bucket in which to store the archive.
@@ -109,6 +116,9 @@ def create_backup( # noqa: PLR0913
        key: The S3 object key.
        pipe_through: A sequence of shell commands through which the archive
            should be piped before uploading.
+        part_size: For multipart uploads, use this as the maximum part size.
+            Defaults to the well-known maximum part size for AWS, which is
+            currently 5 GiB.
    """
    _LOG.info(
        "creating backup of %s (%s)",
@@ -132,7 +142,13 @@ def create_backup( # noqa: PLR0913
    # https://github.com/python/typeshed/issues/3831
    assert pipeline_stdout is not None # noqa: S101
    try:
-        s3.upload_fileobj(pipeline_stdout, bucket, key)
+        upload_non_seekable_stream_via_tempfile(
+            client=s3,
+            bucket=bucket,
+            key=key,
+            stream=pipeline_stdout,
+            part_size=part_size,
+        )
    finally:
        # Allow the pipeline to fail if the upload fails
        pipeline_stdout.close()
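For context, the call that create_backup() now makes can be pictured as the following standalone sketch: a non-seekable "btrfs send" stream piped into the upload helper introduced in the next file. The command line, snapshot path, bucket, key, and boto3 client here are placeholders for illustration, not values taken from this commit, and the pipe_through handling of the real function is omitted.

import subprocess

import boto3

from btrfs2s3.stream_uploader import upload_non_seekable_stream_via_tempfile

# Hypothetical wiring, for illustration only; the real call site is in
# create_backup() above. The "btrfs send" output is not seekable, so the
# helper stages it to temporary files in chunks of at most part_size bytes.
proc = subprocess.Popen(
    ["btrfs", "-q", "send", "/path/to/snapshot"],  # placeholder snapshot path
    stdout=subprocess.PIPE,
)
assert proc.stdout is not None
try:
    upload_non_seekable_stream_via_tempfile(
        client=boto3.client("s3"),
        bucket="example-bucket",  # placeholder
        key="example-key",  # placeholder
        stream=proc.stdout,
        part_size=5 * 2**30,  # DEFAULT_PART_SIZE: the 5 GiB AWS part-size ceiling
    )
finally:
    proc.stdout.close()
    proc.wait()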
src/btrfs2s3/stream_uploader.py: 127 additions & 0 deletions (new file)
@@ -0,0 +1,127 @@
"""Functions for uploading a non-seekable stream to S3."""

from __future__ import annotations

import os
from tempfile import TemporaryFile
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from typing import AnyStr
    from typing import IO
    from typing import Iterator

    from mypy_boto3_s3.client import S3Client
    from mypy_boto3_s3.type_defs import CompletedPartTypeDef

_COPY_BUFFER_SIZE = 2**20


def _copy(*, input_file: IO[bytes], output_file: IO[bytes], count: int) -> int:
    # notes for fast copy:
    # - buffered IO classes have seek position out of sync with underlying file
    # descriptor, mitigate this with SEEK_END
    # - sendfile and splice. which should be used first?
    # - use readinto and reuse buffers for slow mode
    written = 0
    eof = False
    while written < count and not eof:
        # This may result in one or multiple underlying reads
        buf = input_file.read(min(count - written, _COPY_BUFFER_SIZE))
        if not buf:
            eof = True
            break
        # https://docs.python.org/3/library/io.html#io.RawIOBase.write
        # indicates this loop is needed
        offset = 0
        while offset < len(buf):
            offset += output_file.write(buf[offset:])
        written += len(buf)
    return written


def _iter_parts_via_tempfile(stream: IO[bytes], part_size: int) -> Iterator[IO[bytes]]:
    while True:
        with TemporaryFile() as part_file:
            written = _copy(input_file=stream, output_file=part_file, count=part_size)
            if written > 0:
                part_file.seek(0, os.SEEK_SET)
                yield part_file
            if written < part_size:
                break


def _stream_len(stream: IO[AnyStr]) -> int:
    cur = stream.tell()
    end = stream.seek(0, os.SEEK_END)
    stream.seek(cur, os.SEEK_SET)
    return end


def upload_non_seekable_stream_via_tempfile(
    *, stream: IO[bytes], client: S3Client, bucket: str, key: str, part_size: int
) -> None:
"""Upload a non-seekable stream to S3.
This will store the stream in parts to temporary files, of part_size bytes
each. If less than one full part is consumed from the stream, it will
upload the object with put_object. Otherwise, a multipart upload will be
used.
If any error is raised, this function will attempt to cancel the multipart
upload with abort_multipart_upload().
Args:
stream: A stream to upload. The stream may be seekable, but this
function is designed for the non-seekable case.
client: The S3 client object.
bucket: The name of the S3 bucket.
key: The key of the S3 object in the bucket.
part_size: The maximum size of a single part.
"""
    # If the first part is the maximum part size, assume there will be more parts. This
    # is suboptimal in the rare case that the stream is exactly one part length long.
    # The alternative is to attempt to read an extra byte from the stream after the
    # first part has been collected, and append it to the next part. The 1-byte reads
    # will frequently be unaligned and lead to cache thrashing. The optimal strategy
    # would be:
    # - Read the first full part
    # - Read 1 test byte
    # - Read the second part minus one byte
    # - Read the remaining parts as normal
    # This would be a lot of code complexity for a very rare gain.
    upload_id: str | None = None
    completed_parts: list[CompletedPartTypeDef] = []
    try:
        for part_index, part_file in enumerate(
            _iter_parts_via_tempfile(stream, part_size)
        ):
            if upload_id is None and _stream_len(part_file) == part_size:
                upload_id = client.create_multipart_upload(Bucket=bucket, Key=key)[
                    "UploadId"
                ]
            if upload_id is not None:
                part_number = part_index + 1
                up_response = client.upload_part(
                    Bucket=bucket,
                    Key=key,
                    PartNumber=part_number,
                    UploadId=upload_id,
                    Body=part_file,
                )
                completed_parts.append(
                    {"ETag": up_response["ETag"], "PartNumber": part_number}
                )
            else:
                client.put_object(Bucket=bucket, Key=key, Body=part_file)
        if upload_id is not None:
            client.complete_multipart_upload(
                Bucket=bucket,
                Key=key,
                UploadId=upload_id,
                MultipartUpload={"Parts": completed_parts},
            )
    except Exception:
        if upload_id is not None:
            client.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id)
        raise
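The notes at the top of _copy() above mention sendfile and splice as possible fast paths but leave them unimplemented. As a rough sketch of what a splice-based copy could look like (not part of this commit; assumes Linux, Python 3.10+ for os.splice, file objects with empty Python-level buffers, and a pipe on at least one side, such as the stdout of "btrfs send"):

import os
from typing import IO


def _copy_via_splice(*, input_file: IO[bytes], output_file: IO[bytes], count: int) -> int:
    # Hypothetical fast path: move up to count bytes kernel-side with splice(2),
    # avoiding userspace copy buffers. A caller could catch OSError (e.g. on
    # non-Linux platforms or unsupported file types) and fall back to the
    # buffered _copy() above.
    in_fd = input_file.fileno()
    out_fd = output_file.fileno()
    written = 0
    while written < count:
        n = os.splice(in_fd, out_fd, count - written)
        if n == 0:  # EOF on the input side
            break
        written += n
    return written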
Empty file.
tests/stream_uploader/conftest.py: 59 additions & 0 deletions (new file)
@@ -0,0 +1,59 @@
from __future__ import annotations

import os
from tempfile import TemporaryFile
from threading import Thread
from typing import cast
from typing import TYPE_CHECKING

import pytest

if TYPE_CHECKING:
    from typing import IO
    from typing import Iterator


@pytest.fixture(params=[5 * 2**20, 5 * 2**20 + 512, 10 * 2**20])
def data_size(request: pytest.FixtureRequest) -> int:
    return cast(int, request.param)


@pytest.fixture(params=[5 * 2**20, 5 * 2**20 + 512, 10 * 2**20])
def part_size(request: pytest.FixtureRequest) -> int:
    return cast(int, request.param)


@pytest.fixture()
def stream_data(data_size: int) -> bytes:
    return os.urandom(data_size)


@pytest.fixture(params=[-1, 0], ids=["buffered", "unbuffered"])
def buffering(request: pytest.FixtureRequest) -> int:
    return cast(int, request.param)


@pytest.fixture(params=[False, True], ids=["nonseekable", "seekable"])
def seekable(request: pytest.FixtureRequest) -> bool:
    return cast(bool, request.param)


@pytest.fixture()
def stream(buffering: int, seekable: bool, stream_data: bytes) -> Iterator[IO[bytes]]: # noqa: FBT001
    if seekable:
        with TemporaryFile(buffering=buffering) as stream:
            stream.write(stream_data)
            stream.seek(0, os.SEEK_SET)
            yield stream
    else:
        read_fd, write_fd = os.pipe()

        def fill(write_fd: int) -> None:
            os.write(write_fd, stream_data)
            os.close(write_fd)

        fill_thread = Thread(target=fill, args=(write_fd,))
        fill_thread.start()
        with open(read_fd, mode="rb", buffering=buffering) as stream: # noqa: PTH123
            yield stream
        fill_thread.join()
tests/stream_uploader/iter_parts_test.py: 24 additions & 0 deletions (new file)
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from btrfs2s3.stream_uploader import _iter_parts_via_tempfile

if TYPE_CHECKING:
from typing import IO


def test_iter_parts_via_tempfile(
    stream: IO[bytes], stream_data: bytes, part_size: int
) -> None:
    expected_data = [
        stream_data[i : i + part_size] for i in range(0, len(stream_data), part_size)
    ]

    got_data = []
    for part_file in _iter_parts_via_tempfile(stream, part_size):
        assert part_file.seekable()
        assert part_file.tell() == 0
        got_data.append(part_file.read())

    assert got_data == expected_data
@@ -0,0 +1,49 @@
from __future__ import annotations

from typing import TYPE_CHECKING
from unittest.mock import Mock

from btrfs2s3.stream_uploader import upload_non_seekable_stream_via_tempfile
import pytest

if TYPE_CHECKING:
from typing import IO

from mypy_boto3_s3.client import S3Client


def test_file_is_uploaded(
    s3: S3Client, stream: IO[bytes], bucket: str, part_size: int, stream_data: bytes
) -> None:
key = "test-key"

upload_non_seekable_stream_via_tempfile(
client=s3, stream=stream, bucket=bucket, key=key, part_size=part_size
)

assert s3.get_object(Bucket=bucket, Key=key)["Body"].read() == stream_data


class FakeError(Exception):
    pass


def test_multipart_upload_gets_cancelled_on_error(
    s3: S3Client, stream: IO[bytes], bucket: str, part_size: int
) -> None:
key = "test-key"

mock_client = Mock(wraps=s3)
mock_client.put_object.side_effect = FakeError()
mock_client.upload_part.side_effect = FakeError()

with pytest.raises(FakeError):
upload_non_seekable_stream_via_tempfile(
client=mock_client,
stream=stream,
bucket=bucket,
key=key,
part_size=part_size,
)

assert s3.list_multipart_uploads(Bucket=bucket).get("Uploads", []) == []
