From 8aa85d2ece88e148eb63546acaa3212481a6fcd8 Mon Sep 17 00:00:00 2001 From: kyleknap Date: Mon, 30 Oct 2023 16:52:13 -0700 Subject: [PATCH] Add support for stdin/stdout streams for CRT client This also included: * Refactoring the s3transfer crt layer to better organize the branching introduced in supporting both file objects and file name. * Introduce additional helper test methods to reduce verbosity of s3transfer crt test and improve reusability. --- .../next-release/enhancement-s3cp-40300.json | 5 + awscli/customizations/s3/factory.py | 2 - awscli/s3transfer/crt.py | 168 ++++++++--- awscli/topics/s3-config.rst | 3 - tests/functional/s3/__init__.py | 22 +- tests/functional/s3/test_cp_command.py | 33 ++- tests/functional/s3transfer/test_crt.py | 280 +++++++++++++++--- tests/integration/s3transfer/test_crt.py | 144 ++++++++- tests/unit/customizations/s3/test_factory.py | 4 +- tests/unit/s3transfer/test_crt.py | 11 + tests/utils/s3transfer/__init__.py | 3 + 11 files changed, 575 insertions(+), 100 deletions(-) create mode 100644 .changes/next-release/enhancement-s3cp-40300.json diff --git a/.changes/next-release/enhancement-s3cp-40300.json b/.changes/next-release/enhancement-s3cp-40300.json new file mode 100644 index 000000000000..955eaa046b2e --- /dev/null +++ b/.changes/next-release/enhancement-s3cp-40300.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "``s3 cp``", + "description": "Support streaming uploads from stdin and streaming downloads to stdout for CRT transfer client" +} diff --git a/awscli/customizations/s3/factory.py b/awscli/customizations/s3/factory.py index a850b7679b52..6c47a1a977b7 100644 --- a/awscli/customizations/s3/factory.py +++ b/awscli/customizations/s3/factory.py @@ -70,8 +70,6 @@ def create_transfer_manager(self, params, runtime_config, def _compute_transfer_client_type(self, params, runtime_config): if params.get('paths_type') == 's3s3': return constants.DEFAULT_TRANSFER_CLIENT - if params.get('is_stream'): - return constants.DEFAULT_TRANSFER_CLIENT return runtime_config.get( 'preferred_transfer_client', constants.DEFAULT_TRANSFER_CLIENT) diff --git a/awscli/s3transfer/crt.py b/awscli/s3transfer/crt.py index 7b5d13013650..eda2985febfb 100644 --- a/awscli/s3transfer/crt.py +++ b/awscli/s3transfer/crt.py @@ -428,19 +428,12 @@ def _crt_request_from_aws_request(self, aws_request): headers_list.append((name, str(value, 'utf-8'))) crt_headers = awscrt.http.HttpHeaders(headers_list) - # CRT requires body (if it exists) to be an I/O stream. - crt_body_stream = None - if aws_request.body: - if hasattr(aws_request.body, 'seek'): - crt_body_stream = aws_request.body - else: - crt_body_stream = BytesIO(aws_request.body) crt_request = awscrt.http.HttpRequest( method=aws_request.method, path=crt_path, headers=crt_headers, - body_stream=crt_body_stream, + body_stream=aws_request.body, ) return crt_request @@ -453,6 +446,25 @@ def _convert_to_crt_http_request(self, botocore_http_request): crt_request.headers.set("host", url_parts.netloc) if crt_request.headers.get('Content-MD5') is not None: crt_request.headers.remove("Content-MD5") + + # In general, the CRT S3 client expects a content length header. It + # only expects a missing content length header if the body is not + # seekable. However, botocore does not set the content length header + # for GetObject API requests and so we set the content length to zero + # to meet the CRT S3 client's expectation that the content length + # header is set even if there is no body. + if crt_request.headers.get('Content-Length') is None: + if botocore_http_request.body is None: + crt_request.headers.add('Content-Length', "0") + + # Botocore sets the Transfer-Encoding header when it cannot determine + # the content length of the request body (e.g. it's not seekable). + # However, CRT does not support this header, but it supports + # non-seekable bodies. So we remove this header to not cause issues + # in the downstream CRT S3 request. + if crt_request.headers.get('Transfer-Encoding') is not None: + crt_request.headers.remove('Transfer-Encoding') + return crt_request def _capture_http_request(self, request, **kwargs): @@ -555,39 +567,20 @@ def __init__(self, crt_request_serializer, os_utils): def get_make_request_args( self, request_type, call_args, coordinator, future, on_done_after_calls ): - recv_filepath = None - send_filepath = None - s3_meta_request_type = getattr( - S3RequestType, request_type.upper(), S3RequestType.DEFAULT + request_args_handler = getattr( + self, + f'_get_make_request_args_{request_type}', + self._default_get_make_request_args, ) - on_done_before_calls = [] - if s3_meta_request_type == S3RequestType.GET_OBJECT: - final_filepath = call_args.fileobj - recv_filepath = self._os_utils.get_temp_filename(final_filepath) - file_ondone_call = RenameTempFileHandler( - coordinator, final_filepath, recv_filepath, self._os_utils - ) - on_done_before_calls.append(file_ondone_call) - elif s3_meta_request_type == S3RequestType.PUT_OBJECT: - send_filepath = call_args.fileobj - data_len = self._os_utils.get_file_size(send_filepath) - call_args.extra_args["ContentLength"] = data_len - - crt_request = self._request_serializer.serialize_http_request( - request_type, future + return request_args_handler( + request_type=request_type, + call_args=call_args, + coordinator=coordinator, + future=future, + on_done_before_calls=[], + on_done_after_calls=on_done_after_calls, ) - return { - 'request': crt_request, - 'type': s3_meta_request_type, - 'recv_filepath': recv_filepath, - 'send_filepath': send_filepath, - 'on_done': self.get_crt_callback( - future, 'done', on_done_before_calls, on_done_after_calls - ), - 'on_progress': self.get_crt_callback(future, 'progress'), - } - def get_crt_callback( self, future, @@ -613,6 +606,97 @@ def invoke_all_callbacks(*args, **kwargs): return invoke_all_callbacks + def _get_make_request_args_put_object( + self, + request_type, + call_args, + coordinator, + future, + on_done_before_calls, + on_done_after_calls, + ): + send_filepath = None + if isinstance(call_args.fileobj, str): + send_filepath = call_args.fileobj + data_len = self._os_utils.get_file_size(send_filepath) + call_args.extra_args["ContentLength"] = data_len + else: + call_args.extra_args["Body"] = call_args.fileobj + + # Suppress botocore's automatic MD5 calculation by setting an override + # value that will get deleted in the BotocoreCRTRequestSerializer. + # The CRT S3 client is able automatically compute checksums as part of + # requests it makes, and the intention is to configure automatic + # checksums in a future update. + call_args.extra_args["ContentMD5"] = "override-to-be-removed" + + make_request_args = self._default_get_make_request_args( + request_type=request_type, + call_args=call_args, + coordinator=coordinator, + future=future, + on_done_before_calls=on_done_before_calls, + on_done_after_calls=on_done_after_calls, + ) + make_request_args['send_filepath'] = send_filepath + return make_request_args + + def _get_make_request_args_get_object( + self, + request_type, + call_args, + coordinator, + future, + on_done_before_calls, + on_done_after_calls, + ): + recv_filepath = None + on_body = None + if isinstance(call_args.fileobj, str): + final_filepath = call_args.fileobj + recv_filepath = self._os_utils.get_temp_filename(final_filepath) + on_done_before_calls.append( + RenameTempFileHandler( + coordinator, final_filepath, recv_filepath, self._os_utils + ) + ) + else: + on_body = OnBodyFileObjWriter(call_args.fileobj) + + make_request_args = self._default_get_make_request_args( + request_type=request_type, + call_args=call_args, + coordinator=coordinator, + future=future, + on_done_before_calls=on_done_before_calls, + on_done_after_calls=on_done_after_calls, + ) + make_request_args['recv_filepath'] = recv_filepath + make_request_args['on_body'] = on_body + return make_request_args + + def _default_get_make_request_args( + self, + request_type, + call_args, + coordinator, + future, + on_done_before_calls, + on_done_after_calls, + ): + return { + 'request': self._request_serializer.serialize_http_request( + request_type, future + ), + 'type': getattr( + S3RequestType, request_type.upper(), S3RequestType.DEFAULT + ), + 'on_done': self.get_crt_callback( + future, 'done', on_done_before_calls, on_done_after_calls + ), + 'on_progress': self.get_crt_callback(future, 'progress'), + } + class RenameTempFileHandler: def __init__(self, coordinator, final_filename, temp_filename, osutil): @@ -642,3 +726,11 @@ def __init__(self, coordinator): def __call__(self, **kwargs): self._coordinator.set_done_callbacks_complete() + + +class OnBodyFileObjWriter: + def __init__(self, fileobj): + self._fileobj = fileobj + + def __call__(self, chunk, **kwargs): + self._fileobj.write(chunk) diff --git a/awscli/topics/s3-config.rst b/awscli/topics/s3-config.rst index 5663482b7bcb..3b2267850e42 100644 --- a/awscli/topics/s3-config.rst +++ b/awscli/topics/s3-config.rst @@ -297,9 +297,6 @@ files to and from S3. Valid choices are: * S3 to S3 copies - Falls back to using the ``default`` transfer client - * Streaming uploads from standard input and downloads to standard output - - Falls back to using ``default`` transfer client. - * Region redirects - Transfers fail for requests sent to a region that does not match the region of the targeted S3 bucket. diff --git a/tests/functional/s3/__init__.py b/tests/functional/s3/__init__.py index d459dd167246..1c868e9efd75 100644 --- a/tests/functional/s3/__init__.py +++ b/tests/functional/s3/__init__.py @@ -436,6 +436,7 @@ def setUp(self): self.mock_crt_client.return_value.make_request.side_effect = \ self.simulate_make_request_side_effect self.files = FileCreator() + self.expected_download_content = b'content' def tearDown(self): super(BaseCRTTransferClientTest, self).tearDown() @@ -456,6 +457,8 @@ def get_config_file_contents(self): def simulate_make_request_side_effect(self, *args, **kwargs): if kwargs.get('recv_filepath'): self.simulate_file_download(kwargs['recv_filepath']) + elif kwargs.get('on_body'): + self.simulate_on_body(kwargs['on_body']) s3_request = FakeCRTS3Request( future=FakeCRTFuture(kwargs.get('on_done')) ) @@ -465,11 +468,14 @@ def simulate_file_download(self, recv_filepath): parent_dir = os.path.dirname(recv_filepath) if not os.path.isdir(parent_dir): os.makedirs(parent_dir) - with open(recv_filepath, 'w') as f: + with open(recv_filepath, 'wb') as f: # The content is arbitrary as most functional tests are just going # to assert the file exists since it is the CRT writing the # data to the file. - f.write('content') + f.write(self.expected_download_content) + + def simulate_on_body(self, on_body_callback): + on_body_callback(chunk=self.expected_download_content, offset=0) def get_crt_make_request_calls(self): return self.mock_crt_client.return_value.make_request.call_args_list @@ -489,7 +495,8 @@ def assert_crt_make_request_call( self, make_request_call, expected_type, expected_host, expected_path, expected_http_method=None, expected_send_filepath=None, - expected_recv_startswith=None): + expected_recv_startswith=None, + expected_body_content=None): make_request_kwargs = make_request_call[1] self.assertEqual( make_request_kwargs['type'], expected_type) @@ -522,6 +529,15 @@ def assert_crt_make_request_call( f"start with {expected_recv_startswith}" ) ) + if expected_body_content is not None: + # Note: The underlying CRT awscrt.io.InputStream does not expose + # a public read method so we have to reach into the private, + # underlying stream to determine the content. We should update + # to use a public interface if a public interface is ever exposed. + self.assertEqual( + make_request_kwargs['request'].body_stream._stream.read(), + expected_body_content + ) class FakeCRTS3Request: diff --git a/tests/functional/s3/test_cp_command.py b/tests/functional/s3/test_cp_command.py index c789da39b353..2864621bd928 100644 --- a/tests/functional/s3/test_cp_command.py +++ b/tests/functional/s3/test_cp_command.py @@ -2116,25 +2116,38 @@ def test_does_not_use_crt_client_for_copies(self): self.assertEqual(self.get_crt_make_request_calls(), []) self.assert_no_remaining_botocore_responses() - def test_does_not_use_crt_client_for_streaming_upload(self): + def test_streaming_upload_using_crt_client(self): cmdline = [ 's3', 'cp', '-', 's3://bucket/key' ] - self.add_botocore_put_object_response() with mock.patch('sys.stdin', BufferedBytesIO(b'foo')): self.run_command(cmdline) - self.assertEqual(self.get_crt_make_request_calls(), []) - self.assert_no_remaining_botocore_responses() + crt_requests = self.get_crt_make_request_calls() + self.assertEqual(len(crt_requests), 1) + self.assert_crt_make_request_call( + crt_requests[0], + expected_type=S3RequestType.PUT_OBJECT, + expected_host=self.get_virtual_s3_host('bucket'), + expected_path='/key', + expected_body_content=b'foo', + ) - def test_does_not_use_crt_client_for_streaming_download(self): + def test_streaming_download_using_crt_client(self): cmdline = [ 's3', 'cp', 's3://bucket/key', '-' ] - self.add_botocore_head_object_response() - self.add_botocore_get_object_response() - self.run_command(cmdline) - self.assertEqual(self.get_crt_make_request_calls(), []) - self.assert_no_remaining_botocore_responses() + result = self.run_command(cmdline) + crt_requests = self.get_crt_make_request_calls() + self.assertEqual(len(crt_requests), 1) + self.assert_crt_make_request_call( + crt_requests[0], + expected_type=S3RequestType.GET_OBJECT, + expected_host=self.get_virtual_s3_host('bucket'), + expected_path='/key', + ) + self.assertEqual( + result.stdout, self.expected_download_content.decode('utf-8') + ) def test_respects_region_parameter(self): filename = self.files.create_file('myfile', 'mycontent') diff --git a/tests/functional/s3transfer/test_crt.py b/tests/functional/s3transfer/test_crt.py index 0ead2959fea1..f09eb12c8059 100644 --- a/tests/functional/s3transfer/test_crt.py +++ b/tests/functional/s3transfer/test_crt.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import fnmatch +import io import threading import time from concurrent.futures import Future @@ -18,7 +19,15 @@ from botocore.session import Session from s3transfer.subscribers import BaseSubscriber -from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest +from tests import ( + HAS_CRT, + FileCreator, + NonSeekableReader, + NonSeekableWriter, + mock, + requires_crt, + unittest, +) if HAS_CRT: import awscrt @@ -60,13 +69,19 @@ def setUp(self): self.region = 'us-west-2' self.bucket = "test_bucket" self.key = "test_key" + self.expected_content = b'my content' + self.expected_download_content = b'new content' self.files = FileCreator() - self.filename = self.files.create_file('myfile', 'my content') + self.filename = self.files.create_file( + 'myfile', self.expected_content, mode='wb' + ) self.expected_path = "/" + self.bucket + "/" + self.key self.expected_host = "s3.%s.amazonaws.com" % (self.region) self.s3_request = mock.Mock(awscrt.s3.S3Request) self.s3_crt_client = mock.Mock(awscrt.s3.S3Client) - self.s3_crt_client.make_request.return_value = self.s3_request + self.s3_crt_client.make_request.side_effect = ( + self._simulate_make_request_side_effect + ) self.session = Session() self.session.set_config_variable('region', self.region) self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( @@ -81,6 +96,44 @@ def setUp(self): def tearDown(self): self.files.remove_all() + def _assert_expected_crt_http_request( + self, + crt_http_request, + expected_http_method='GET', + expected_host=None, + expected_path=None, + expected_body_content=None, + expected_content_length=None, + expected_missing_headers=None, + ): + if expected_host is None: + expected_host = self.expected_host + if expected_path is None: + expected_path = self.expected_path + self.assertEqual(crt_http_request.method, expected_http_method) + self.assertEqual(crt_http_request.headers.get("host"), expected_host) + self.assertEqual(crt_http_request.path, expected_path) + if expected_body_content is not None: + # Note: The underlying CRT awscrt.io.InputStream does not expose + # a public read method so we have to reach into the private, + # underlying stream to determine the content. We should update + # to use a public interface if a public interface is ever exposed. + self.assertEqual( + crt_http_request.body_stream._stream.read(), + expected_body_content, + ) + if expected_content_length is not None: + self.assertEqual( + crt_http_request.headers.get('Content-Length'), + str(expected_content_length), + ) + if expected_missing_headers is not None: + header_names = [ + header[0] for header in crt_http_request.headers + ] + for expected_missing_header in expected_missing_headers: + self.assertNotIn(expected_missing_header, header_names) + def _assert_subscribers_called(self, expected_future=None): self.assertTrue(self.record_subscriber.on_queued_called) self.assertTrue(self.record_subscriber.on_done_called) @@ -99,47 +152,125 @@ def _invoke_done_callbacks(self, **kwargs): on_done(error=None) def _simulate_file_download(self, recv_filepath): - self.files.create_file(recv_filepath, "fake response") + self.files.create_file( + recv_filepath, self.expected_download_content, mode='wb' + ) + + def _simulate_on_body_download(self, on_body_callback): + on_body_callback(chunk=self.expected_download_content, offset=0) def _simulate_make_request_side_effect(self, **kwargs): if kwargs.get('recv_filepath'): self._simulate_file_download(kwargs['recv_filepath']) + if kwargs.get('on_body'): + self._simulate_on_body_download(kwargs['on_body']) self._invoke_done_callbacks() - return mock.DEFAULT + return self.s3_request def test_upload(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect - ) future = self.transfer_manager.upload( self.filename, self.bucket, self.key, {}, [self.record_subscriber] ) future.result() - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - self.assertEqual(callargs_kwargs["send_filepath"], self.filename) - self.assertIsNone(callargs_kwargs["recv_filepath"]) + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.PUT_OBJECT + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.PUT_OBJECT, + 'send_filepath': self.filename, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='PUT', + expected_content_length=len(self.expected_content), + expected_missing_headers=['Content-MD5'], ) - crt_request = callargs_kwargs["request"] - self.assertEqual("PUT", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) self._assert_subscribers_called(future) - def test_download(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect + def test_upload_from_seekable_stream(self): + with open(self.filename, 'rb') as f: + future = self.transfer_manager.upload( + f, self.bucket, self.key, {}, [self.record_subscriber] + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.PUT_OBJECT, + 'send_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='PUT', + expected_body_content=self.expected_content, + expected_content_length=len(self.expected_content), + expected_missing_headers=['Content-MD5'], + ) + self._assert_subscribers_called(future) + + def test_upload_from_nonseekable_stream(self): + nonseekable_stream = NonSeekableReader(self.expected_content) + future = self.transfer_manager.upload( + nonseekable_stream, + self.bucket, + self.key, + {}, + [self.record_subscriber], + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.PUT_OBJECT, + 'send_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + }, ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='PUT', + expected_body_content=self.expected_content, + expected_missing_headers=[ + 'Content-MD5', + 'Content-Length', + 'Transfer-Encoding', + ], + ) + self._assert_subscribers_called(future) + + def test_download(self): future = self.transfer_manager.download( self.bucket, self.key, self.filename, {}, [self.record_subscriber] ) future.result() - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.GET_OBJECT, + 'recv_filepath': mock.ANY, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'on_body': None, + }, + ) # the recv_filepath will be set to a temporary file path with some # random suffix self.assertTrue( @@ -148,42 +279,109 @@ def test_download(self): f'{self.filename}.*', ) ) - self.assertIsNone(callargs_kwargs["send_filepath"]) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='GET', + expected_content_length=0, + ) + self._assert_subscribers_called(future) + with open(self.filename, 'rb') as f: + # Check the fake response overwrites the file because of download + self.assertEqual(f.read(), self.expected_download_content) + + def test_download_to_seekable_stream(self): + with open(self.filename, 'wb') as f: + future = self.transfer_manager.download( + self.bucket, self.key, f, {}, [self.record_subscriber] + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.GET_OBJECT + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.GET_OBJECT, + 'recv_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'on_body': mock.ANY, + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='GET', + expected_content_length=0, ) - crt_request = callargs_kwargs["request"] - self.assertEqual("GET", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) self._assert_subscribers_called(future) with open(self.filename, 'rb') as f: # Check the fake response overwrites the file because of download - self.assertEqual(f.read(), b'fake response') + self.assertEqual(f.read(), self.expected_download_content) - def test_delete(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect + def test_download_to_nonseekable_stream(self): + underlying_stream = io.BytesIO() + nonseekable_stream = NonSeekableWriter(underlying_stream) + future = self.transfer_manager.download( + self.bucket, + self.key, + nonseekable_stream, + {}, + [self.record_subscriber], + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.GET_OBJECT, + 'recv_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'on_body': mock.ANY, + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='GET', + expected_content_length=0, + ) + self._assert_subscribers_called(future) + self.assertEqual( + underlying_stream.getvalue(), self.expected_download_content ) + + def test_delete(self): future = self.transfer_manager.delete( self.bucket, self.key, {}, [self.record_subscriber] ) future.result() - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - self.assertIsNone(callargs_kwargs["send_filepath"]) - self.assertIsNone(callargs_kwargs["recv_filepath"]) + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.DEFAULT + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.DEFAULT, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='DELETE', + expected_content_length=0, ) - crt_request = callargs_kwargs["request"] - self.assertEqual("DELETE", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) self._assert_subscribers_called(future) def test_blocks_when_max_requests_processes_reached(self): + self.s3_crt_client.make_request.return_value = self.s3_request + # We simulate blocking by not invoking the on_done callbacks for + # all of the requests we send. The default side effect invokes all + # callbacks so we need to unset the side effect to avoid on_done from + # being called in the child threads. + self.s3_crt_client.make_request.side_effect = None futures = [] callargs = (self.bucket, self.key, self.filename, {}, []) max_request_processes = 128 # the hard coded max processes diff --git a/tests/integration/s3transfer/test_crt.py b/tests/integration/s3transfer/test_crt.py index 9f6c533dbc8a..9580aecf50f8 100644 --- a/tests/integration/s3transfer/test_crt.py +++ b/tests/integration/s3transfer/test_crt.py @@ -11,11 +11,18 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import glob +import io import os from s3transfer.subscribers import BaseSubscriber from s3transfer.utils import OSUtils -from tests import HAS_CRT, assert_files_equal, requires_crt +from tests import ( + HAS_CRT, + NonSeekableReader, + NonSeekableWriter, + assert_files_equal, + requires_crt, +) from tests.integration.s3transfer import BaseTransferManagerIntegTest if HAS_CRT: @@ -44,6 +51,11 @@ def on_done(self, **kwargs): class TestCRTS3Transfers(BaseTransferManagerIntegTest): """Tests for the high level s3transfer based on CRT implementation.""" + def setUp(self): + super().setUp() + self.s3_key = 's3key.txt' + self.download_path = os.path.join(self.files.rootdir, 'download.txt') + def _create_s3_transfer(self): self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( self.session, client_kwargs={'region_name': self.region} @@ -58,6 +70,40 @@ def _create_s3_transfer(self): self.s3_crt_client, self.request_serializer ) + def _upload_with_crt_transfer_manager(self, fileobj, key=None): + if key is None: + key = self.s3_key + self.addCleanup(self.delete_object, key) + with self._create_s3_transfer() as transfer: + future = transfer.upload( + fileobj, + self.bucket_name, + key, + subscribers=[self.record_subscriber], + ) + future.result() + + def _download_with_crt_transfer_manager(self, fileobj, key=None): + if key is None: + key = self.s3_key + self.addCleanup(self.delete_object, key) + with self._create_s3_transfer() as transfer: + future = transfer.download( + self.bucket_name, + key, + fileobj, + subscribers=[self.record_subscriber], + ) + future.result() + + def _assert_expected_s3_object(self, key, expected_size=None): + self.assertTrue(self.object_exists(key)) + if expected_size is not None: + response = self.client.head_object( + Bucket=self.bucket_name, Key=key + ) + self.assertEqual(response['ContentLength'], expected_size) + def _assert_has_public_read_acl(self, response): grants = response['Grants'] public_read = [ @@ -176,6 +222,43 @@ def test_upload_file_above_threshold_with_ssec(self): self.assertEqual(response['SSECustomerAlgorithm'], 'AES256') self._assert_subscribers_called(file_size) + def test_upload_seekable_stream(self): + size = 1024 * 1024 + self._upload_with_crt_transfer_manager(io.BytesIO(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_multipart_upload_seekable_stream(self): + size = 20 * 1024 * 1024 + self._upload_with_crt_transfer_manager(io.BytesIO(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_upload_nonseekable_stream(self): + size = 1024 * 1024 + self._upload_with_crt_transfer_manager(NonSeekableReader(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_multipart_upload_nonseekable_stream(self): + size = 20 * 1024 * 1024 + self._upload_with_crt_transfer_manager(NonSeekableReader(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_upload_empty_file(self): + size = 0 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self._upload_with_crt_transfer_manager(filename) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_upload_empty_stream(self): + size = 0 + self._upload_with_crt_transfer_manager(io.BytesIO(b'')) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + def test_can_send_extra_params_on_download(self): # We're picking the customer provided sse feature # of S3 to test the extra_args functionality of @@ -244,6 +327,65 @@ def test_download_above_threshold(self): file_size = self.osutil.get_file_size(download_path) self._assert_subscribers_called(file_size) + def test_download_seekable_stream(self): + size = 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(f) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_multipart_download_seekable_stream(self): + size = 20 * 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(f) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_download_nonseekable_stream(self): + size = 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(NonSeekableWriter(f)) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_multipart_download_nonseekable_stream(self): + size = 20 * 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(NonSeekableWriter(f)) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_download_empty_file(self): + size = 0 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + self._download_with_crt_transfer_manager(self.download_path) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_download_empty_stream(self): + size = 0 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(f) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + def test_delete(self): transfer = self._create_s3_transfer() filename = self.files.create_file_with_size( diff --git a/tests/unit/customizations/s3/test_factory.py b/tests/unit/customizations/s3/test_factory.py index 41eef50332aa..f1892359cfff 100644 --- a/tests/unit/customizations/s3/test_factory.py +++ b/tests/unit/customizations/s3/test_factory.py @@ -196,13 +196,13 @@ def test_creates_default_manager_for_copies(self): self.params, self.runtime_config) self.assert_is_default_manager(transfer_manager) - def test_creates_default_manager_when_streaming_operation(self): + def test_creates_crt_manager_when_streaming_operation(self): self.params['is_stream'] = True self.runtime_config = self.get_runtime_config( preferred_transfer_client='crt') transfer_manager = self.factory.create_transfer_manager( self.params, self.runtime_config) - self.assert_is_default_manager(transfer_manager) + self.assert_is_crt_manager(transfer_manager) @mock.patch('s3transfer.crt.S3Client') def test_uses_region_parameter_for_crt_manager(self, mock_crt_client): diff --git a/tests/unit/s3transfer/test_crt.py b/tests/unit/s3transfer/test_crt.py index b6ad3245e71d..aadd38276ac6 100644 --- a/tests/unit/s3transfer/test_crt.py +++ b/tests/unit/s3transfer/test_crt.py @@ -10,6 +10,8 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. +import io + from botocore.credentials import CredentialResolver, ReadOnlyCredentials from botocore.session import Session @@ -171,3 +173,12 @@ def test_set_exception_can_override_previous_exception(self): self.future.set_exception(CustomFutureException()) with self.assertRaises(CustomFutureException): self.future.result() + + +@requires_crt +class TestOnBodyFileObjWriter(unittest.TestCase): + def test_call(self): + fileobj = io.BytesIO() + writer = s3transfer.crt.OnBodyFileObjWriter(fileobj) + writer(chunk=b'content') + self.assertEqual(fileobj.getvalue(), b'content') diff --git a/tests/utils/s3transfer/__init__.py b/tests/utils/s3transfer/__init__.py index 9ebbf1e9de20..6fe44f6099dc 100644 --- a/tests/utils/s3transfer/__init__.py +++ b/tests/utils/s3transfer/__init__.py @@ -500,6 +500,9 @@ def write(self, b): def read(self, n=-1): return self._data.read(n) + def readinto(self, b): + return self._data.readinto(b) + class NonSeekableWriter(io.RawIOBase): def __init__(self, fileobj):