diff --git a/src/btrfs2s3/action.py b/src/btrfs2s3/action.py index 6931445..c533dc8 100644 --- a/src/btrfs2s3/action.py +++ b/src/btrfs2s3/action.py @@ -14,6 +14,7 @@ from btrfs2s3._internal.util import NULL_UUID from btrfs2s3._internal.util import SubvolumeFlags +from btrfs2s3.stream_uploader import upload_non_seekable_stream_via_tempfile from btrfs2s3.thunk import Thunk from btrfs2s3.thunk import ThunkArg @@ -26,6 +27,8 @@ _LOG = logging.getLogger(__name__) +DEFAULT_PART_SIZE = 5 * 2**30 + def create_snapshot(*, source: Path, path: Path) -> None: """Create a read-only snapshot of a subvolume. @@ -94,12 +97,16 @@ def create_backup( # noqa: PLR0913 send_parent: Path | None, key: str, pipe_through: Sequence[Sequence[str]] = (), + part_size: int = DEFAULT_PART_SIZE, ) -> None: """Stores a btrfs archive in S3. This will spawn "btrfs -q send" as a subprocess, as there is currently no way to create a btrfs-send stream via pure python. + This will temporarily store the "btrfs send" stream on disk. A maximum of + part_size bytes will be stored at a time. + Args: s3: An S3 client. bucket: The bucket in which to store the archive. @@ -109,6 +116,9 @@ def create_backup( # noqa: PLR0913 key: The S3 object key. pipe_through: A sequence of shell commands through which the archive should be piped before uploading. + part_size: For multipart uploads, use this as the maximum part size. + Defaults to the well-known maximum part size for AWS, which is + currently 5 GiB. """ _LOG.info( "creating backup of %s (%s)", @@ -132,7 +142,13 @@ def create_backup( # noqa: PLR0913 # https://github.com/python/typeshed/issues/3831 assert pipeline_stdout is not None # noqa: S101 try: - s3.upload_fileobj(pipeline_stdout, bucket, key) + upload_non_seekable_stream_via_tempfile( + client=s3, + bucket=bucket, + key=key, + stream=pipeline_stdout, + part_size=part_size, + ) finally: # Allow the pipeline to fail if the upload fails pipeline_stdout.close() diff --git a/src/btrfs2s3/stream_uploader.py b/src/btrfs2s3/stream_uploader.py new file mode 100644 index 0000000..913045b --- /dev/null +++ b/src/btrfs2s3/stream_uploader.py @@ -0,0 +1,127 @@ +"""Functions for uploading a non-seekable stream to S3.""" + +from __future__ import annotations + +import os +from tempfile import TemporaryFile +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from typing import AnyStr + from typing import IO + from typing import Iterator + + from mypy_boto3_s3.client import S3Client + from mypy_boto3_s3.type_defs import CompletedPartTypeDef + +_COPY_BUFFER_SIZE = 2**20 + + +def _copy(*, input_file: IO[bytes], output_file: IO[bytes], count: int) -> int: + # notes for fast copy: + # - buffered IO classes have seek position out of sync with underlying file + # descriptor, mitigate this with SEEK_END + # - sendfile and splice. which should be used first? + # - use readinto and reuse buffers for slow mode + written = 0 + eof = False + while written < count and not eof: + # This may result in one or multiple underlying reads + buf = input_file.read(min(count - written, _COPY_BUFFER_SIZE)) + if not buf: + eof = True + break + # https://docs.python.org/3/library/io.html#io.RawIOBase.write + # indicates this loop is needed + offset = 0 + while offset < len(buf): + offset += output_file.write(buf[offset:]) + written += len(buf) + return written + + +def _iter_parts_via_tempfile(stream: IO[bytes], part_size: int) -> Iterator[IO[bytes]]: + while True: + with TemporaryFile() as part_file: + written = _copy(input_file=stream, output_file=part_file, count=part_size) + if written > 0: + part_file.seek(0, os.SEEK_SET) + yield part_file + if written < part_size: + break + + +def _stream_len(stream: IO[AnyStr]) -> int: + cur = stream.tell() + end = stream.seek(0, os.SEEK_END) + stream.seek(cur, os.SEEK_SET) + return end + + +def upload_non_seekable_stream_via_tempfile( + *, stream: IO[bytes], client: S3Client, bucket: str, key: str, part_size: int +) -> None: + """Upload a non-seekable stream to S3. + + This will store the stream in parts to temporary files, of part_size bytes + each. If less than one full part is consumed from the stream, it will + upload the object with put_object. Otherwise, a multipart upload will be + used. + + If any error is raised, this function will attempt to cancel the multipart + upload with abort_multipart_upload(). + + Args: + stream: A stream to upload. The stream may be seekable, but this + function is designed for the non-seekable case. + client: The S3 client object. + bucket: The name of the S3 bucket. + key: The key of the S3 object in the bucket. + part_size: The maximum size of a single part. + """ + # If the first part is the maximum part size, assume there will be more parts. This + # is suboptimal in the rare case that the stream is exactly one part length long. + # The alternative is to attempt to read an extra byte from the stream after the + # first part has been collected, and append it to the next part. The 1-byte reads + # will frequently be unaligned and lead to cache thrashing. The optimal strategy + # would be: + # - Read the first full part + # - Read 1 test byte + # - Read the second part minus one byte + # - Read the remaining parts as normal + # This would be a lot of code complexity for a very rare gain. + upload_id: str | None = None + completed_parts: list[CompletedPartTypeDef] = [] + try: + for part_index, part_file in enumerate( + _iter_parts_via_tempfile(stream, part_size) + ): + if upload_id is None and _stream_len(part_file) == part_size: + upload_id = client.create_multipart_upload(Bucket=bucket, Key=key)[ + "UploadId" + ] + if upload_id is not None: + part_number = part_index + 1 + up_response = client.upload_part( + Bucket=bucket, + Key=key, + PartNumber=part_number, + UploadId=upload_id, + Body=part_file, + ) + completed_parts.append( + {"ETag": up_response["ETag"], "PartNumber": part_number} + ) + else: + client.put_object(Bucket=bucket, Key=key, Body=part_file) + if upload_id is not None: + client.complete_multipart_upload( + Bucket=bucket, + Key=key, + UploadId=upload_id, + MultipartUpload={"Parts": completed_parts}, + ) + except Exception: + if upload_id is not None: + client.abort_multipart_upload(Bucket=bucket, Key=key, UploadId=upload_id) + raise diff --git a/tests/stream_uploader/__init__.py b/tests/stream_uploader/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/stream_uploader/conftest.py b/tests/stream_uploader/conftest.py new file mode 100644 index 0000000..f891bac --- /dev/null +++ b/tests/stream_uploader/conftest.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import os +from tempfile import TemporaryFile +from threading import Thread +from typing import cast +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from typing import IO + from typing import Iterator + + +@pytest.fixture(params=[5 * 2**20, 5 * 2**20 + 512, 10 * 2**20]) +def data_size(request: pytest.FixtureRequest) -> int: + return cast(int, request.param) + + +@pytest.fixture(params=[5 * 2**20, 5 * 2**20 + 512, 10 * 2**20]) +def part_size(request: pytest.FixtureRequest) -> int: + return cast(int, request.param) + + +@pytest.fixture() +def stream_data(data_size: int) -> bytes: + return os.urandom(data_size) + + +@pytest.fixture(params=[-1, 0], ids=["buffered", "unbuffered"]) +def buffering(request: pytest.FixtureRequest) -> int: + return cast(int, request.param) + + +@pytest.fixture(params=[False, True], ids=["nonseekable", "seekable"]) +def seekable(request: pytest.FixtureRequest) -> bool: + return cast(bool, request.param) + + +@pytest.fixture() +def stream(buffering: int, seekable: bool, stream_data: bytes) -> Iterator[IO[bytes]]: # noqa: FBT001 + if seekable: + with TemporaryFile(buffering=buffering) as stream: + stream.write(stream_data) + stream.seek(0, os.SEEK_SET) + yield stream + else: + read_fd, write_fd = os.pipe() + + def fill(write_fd: int) -> None: + os.write(write_fd, stream_data) + os.close(write_fd) + + fill_thread = Thread(target=fill, args=(write_fd,)) + fill_thread.start() + with open(read_fd, mode="rb", buffering=buffering) as stream: # noqa: PTH123 + yield stream + fill_thread.join() diff --git a/tests/stream_uploader/iter_parts_test.py b/tests/stream_uploader/iter_parts_test.py new file mode 100644 index 0000000..2c39492 --- /dev/null +++ b/tests/stream_uploader/iter_parts_test.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from btrfs2s3.stream_uploader import _iter_parts_via_tempfile + +if TYPE_CHECKING: + from typing import IO + + +def test_iter_parts_via_tempfile( + stream: IO[bytes], stream_data: bytes, part_size: int +) -> None: + expected_data = [ + stream_data[i : i + part_size] for i in range(0, len(stream_data), part_size) + ] + + got_data = [] + for part_file in _iter_parts_via_tempfile(stream, part_size): + assert part_file.seekable() + assert part_file.tell() == 0 + got_data.append(part_file.read()) + + assert got_data == expected_data diff --git a/tests/stream_uploader/upload_non_seekable_stream_via_tempfile_test.py b/tests/stream_uploader/upload_non_seekable_stream_via_tempfile_test.py new file mode 100644 index 0000000..97ab9c3 --- /dev/null +++ b/tests/stream_uploader/upload_non_seekable_stream_via_tempfile_test.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest.mock import Mock + +from btrfs2s3.stream_uploader import upload_non_seekable_stream_via_tempfile +import pytest + +if TYPE_CHECKING: + from typing import IO + + from mypy_boto3_s3.client import S3Client + + +def test_file_is_uploaded( + s3: S3Client, stream: IO[bytes], bucket: str, part_size: int, stream_data: bytes +) -> None: + key = "test-key" + + upload_non_seekable_stream_via_tempfile( + client=s3, stream=stream, bucket=bucket, key=key, part_size=part_size + ) + + assert s3.get_object(Bucket=bucket, Key=key)["Body"].read() == stream_data + + +class FakeError(Exception): + pass + + +def test_multipart_upload_gets_cancelled_on_error( + s3: S3Client, stream: IO[bytes], bucket: str, part_size: int +) -> None: + key = "test-key" + + mock_client = Mock(wraps=s3) + mock_client.put_object.side_effect = FakeError() + mock_client.upload_part.side_effect = FakeError() + + with pytest.raises(FakeError): + upload_non_seekable_stream_via_tempfile( + client=mock_client, + stream=stream, + bucket=bucket, + key=key, + part_size=part_size, + ) + + assert s3.list_multipart_uploads(Bucket=bucket).get("Uploads", []) == []