diff --git a/.changes/next-release/enhancement-Botocore-82899.json b/.changes/next-release/enhancement-Botocore-82899.json new file mode 100644 index 00000000..ee038c25 --- /dev/null +++ b/.changes/next-release/enhancement-Botocore-82899.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "``Botocore``", + "description": "S3Transfer now requires Botocore >=1.32.7" +} diff --git a/.changes/next-release/enhancement-crt-28261.json b/.changes/next-release/enhancement-crt-28261.json new file mode 100644 index 00000000..ee65dd62 --- /dev/null +++ b/.changes/next-release/enhancement-crt-28261.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "``crt``", + "description": "Update ``target_throughput`` defaults. If not configured, s3transfer will use the AWS CRT to attempt to determine a recommended target throughput to use based on the system. If there is no recommended throughput, s3transfer now falls back to ten gigabits per second." +} diff --git a/.changes/next-release/enhancement-crt-30257.json b/.changes/next-release/enhancement-crt-30257.json new file mode 100644 index 00000000..dca8d87d --- /dev/null +++ b/.changes/next-release/enhancement-crt-30257.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "``crt``", + "description": "Automatically configure CRC32 checksums for uploads and checksum validation for downloads through the CRT transfer manager." +} diff --git a/.changes/next-release/enhancement-crt-51520.json b/.changes/next-release/enhancement-crt-51520.json new file mode 100644 index 00000000..2bf4a0cf --- /dev/null +++ b/.changes/next-release/enhancement-crt-51520.json @@ -0,0 +1,5 @@ +{ + "type": "enhancement", + "category": "``crt``", + "description": "Add support for uploading and downloading file-like objects using CRT transfer manager. It supports both seekable and non-seekable file-like objects." +} diff --git a/.changes/next-release/feature-crt-5777.json b/.changes/next-release/feature-crt-5777.json new file mode 100644 index 00000000..4b0391a3 --- /dev/null +++ b/.changes/next-release/feature-crt-5777.json @@ -0,0 +1,5 @@ +{ + "type": "feature", + "category": "``crt``", + "description": "S3transfer now supports a wider range of CRT functionality for uploads to improve throughput in the CLI/Boto3." +} diff --git a/s3transfer/crt.py b/s3transfer/crt.py index 7b5d1301..24fa7976 100644 --- a/s3transfer/crt.py +++ b/s3transfer/crt.py @@ -25,49 +25,58 @@ EventLoopGroup, TlsContextOptions, ) -from awscrt.s3 import S3Client, S3RequestTlsMode, S3RequestType +from awscrt.s3 import ( + S3Client, + S3RequestTlsMode, + S3RequestType, + S3ResponseError, + get_recommended_throughput_target_gbps, +) from botocore import UNSIGNED from botocore.compat import urlsplit from botocore.config import Config from botocore.exceptions import NoCredentialsError -from s3transfer.constants import GB, MB +from s3transfer.constants import MB from s3transfer.exceptions import TransferNotDoneError from s3transfer.futures import BaseTransferFuture, BaseTransferMeta from s3transfer.utils import CallArgs, OSUtils, get_callbacks logger = logging.getLogger(__name__) - -class CRTCredentialProviderAdapter: - def __init__(self, botocore_credential_provider): - self._botocore_credential_provider = botocore_credential_provider - self._loaded_credentials = None - self._lock = threading.Lock() - - def __call__(self): - credentials = self._get_credentials().get_frozen_credentials() - return AwsCredentials( - credentials.access_key, credentials.secret_key, credentials.token - ) - - def _get_credentials(self): - with self._lock: - if self._loaded_credentials is None: - loaded_creds = ( - self._botocore_credential_provider.load_credentials() - ) - if loaded_creds is None: - raise NoCredentialsError() - self._loaded_credentials = loaded_creds - return self._loaded_credentials +CRT_S3_PROCESS_LOCK = None + + +def acquire_crt_s3_process_lock(name): + # Currently, the CRT S3 client performs best when there is only one + # instance of it running on a host. This lock allows an application to + # signal across processes whether there is another process of the same + # application using the CRT S3 client and prevent spawning more than one + # CRT S3 clients running on the system for that application. + # + # NOTE: When acquiring the CRT process lock, the lock automatically is + # released when the lock object is garbage collected. So, the CRT process + # lock is set as a global so that it is not unintentionally garbage + # collected/released if reference of the lock is lost. + global CRT_S3_PROCESS_LOCK + if CRT_S3_PROCESS_LOCK is None: + crt_lock = awscrt.s3.CrossProcessLock(name) + try: + crt_lock.acquire() + except RuntimeError: + # If there is another process that is holding the lock, the CRT + # returns a RuntimeError. We return None here to signal that our + # current process was not able to acquire the lock. + return None + CRT_S3_PROCESS_LOCK = crt_lock + return CRT_S3_PROCESS_LOCK def create_s3_crt_client( region, - botocore_credential_provider=None, + crt_credentials_provider=None, num_threads=None, - target_throughput=5 * GB / 8, + target_throughput=None, part_size=8 * MB, use_ssl=True, verify=None, @@ -76,18 +85,24 @@ def create_s3_crt_client( :type region: str :param region: The region used for signing - :type botocore_credential_provider: - Optional[botocore.credentials.CredentialResolver] - :param botocore_credential_provider: Provide credentials for CRT - to sign the request if not set, the request will not be signed + :type crt_credentials_provider: + Optional[awscrt.auth.AwsCredentialsProvider] + :param crt_credentials_provider: CRT AWS credentials provider + to use to sign requests. If not set, requests will not be signed. :type num_threads: Optional[int] :param num_threads: Number of worker threads generated. Default is the number of processors in the machine. :type target_throughput: Optional[int] - :param target_throughput: Throughput target in Bytes. - Default is 0.625 GB/s (which translates to 5 Gb/s). + :param target_throughput: Throughput target in bytes per second. + By default, CRT will automatically attempt to choose a target + throughput that matches the system's maximum network throughput. + Currently, if CRT is unable to determine the maximum network + throughput, a fallback target throughput of ``1_250_000_000`` bytes + per second (which translates to 10 gigabits per second, or 1.16 + gibibytes per second) is used. To set a specific target + throughput, set a value for this parameter. :type part_size: Optional[int] :param part_size: Size, in Bytes, of parts that files will be downloaded @@ -113,7 +128,6 @@ def create_s3_crt_client( event_loop_group = EventLoopGroup(num_threads) host_resolver = DefaultHostResolver(event_loop_group) bootstrap = ClientBootstrap(event_loop_group, host_resolver) - provider = None tls_connection_options = None tls_mode = ( @@ -129,19 +143,13 @@ def create_s3_crt_client( tls_ctx_options.verify_peer = False client_tls_option = ClientTlsContext(tls_ctx_options) tls_connection_options = client_tls_option.new_connection_options() - if botocore_credential_provider: - credentails_provider_adapter = CRTCredentialProviderAdapter( - botocore_credential_provider - ) - provider = AwsCredentialsProvider.new_delegate( - credentails_provider_adapter - ) - - target_gbps = target_throughput * 8 / GB + target_gbps = _get_crt_throughput_target_gbps( + provided_throughput_target_bytes=target_throughput + ) return S3Client( bootstrap=bootstrap, region=region, - credential_provider=provider, + credential_provider=crt_credentials_provider, part_size=part_size, tls_mode=tls_mode, tls_connection_options=tls_connection_options, @@ -149,6 +157,24 @@ def create_s3_crt_client( ) +def _get_crt_throughput_target_gbps(provided_throughput_target_bytes=None): + if provided_throughput_target_bytes is None: + target_gbps = get_recommended_throughput_target_gbps() + logger.debug( + 'Recommended CRT throughput target in gbps: %s', target_gbps + ) + if target_gbps is None: + target_gbps = 10.0 + else: + # NOTE: The GB constant in s3transfer is technically a gibibyte. The + # GB constant is not used here because the CRT interprets gigabits + # for networking as a base power of 10 + # (i.e. 1000 ** 3 instead of 1024 ** 3). + target_gbps = provided_throughput_target_bytes * 8 / 1_000_000_000 + logger.debug('Using CRT throughput target in gbps: %s', target_gbps) + return target_gbps + + class CRTTransferManager: def __init__(self, crt_s3_client, crt_request_serializer, osutil=None): """A transfer manager interface for Amazon S3 on CRT s3 client. @@ -171,6 +197,9 @@ def __init__(self, crt_s3_client, crt_request_serializer, osutil=None): self._s3_args_creator = S3ClientArgsCreator( crt_request_serializer, self._osutil ) + self._crt_exception_translator = ( + crt_request_serializer.translate_crt_exception + ) self._future_coordinators = [] self._semaphore = threading.Semaphore(128) # not configurable # A counter to create unique id's for each transfer submitted. @@ -206,6 +235,7 @@ def upload(self, fileobj, bucket, key, extra_args=None, subscribers=None): extra_args = {} if subscribers is None: subscribers = {} + self._validate_checksum_algorithm_supported(extra_args) callargs = CallArgs( bucket=bucket, key=key, @@ -231,6 +261,17 @@ def delete(self, bucket, key, extra_args=None, subscribers=None): def shutdown(self, cancel=False): self._shutdown(cancel) + def _validate_checksum_algorithm_supported(self, extra_args): + checksum_algorithm = extra_args.get('ChecksumAlgorithm') + if checksum_algorithm is None: + return + supported_algorithms = list(awscrt.s3.S3ChecksumAlgorithm.__members__) + if checksum_algorithm.upper() not in supported_algorithms: + raise ValueError( + f'ChecksumAlgorithm: {checksum_algorithm} not supported. ' + f'Supported algorithms are: {supported_algorithms}' + ) + def _cancel_transfers(self): for coordinator in self._future_coordinators: if not coordinator.done(): @@ -262,7 +303,10 @@ def _release_semaphore(self, **kwargs): def _submit_transfer(self, request_type, call_args): on_done_after_calls = [self._release_semaphore] - coordinator = CRTTransferCoordinator(transfer_id=self._id_counter) + coordinator = CRTTransferCoordinator( + transfer_id=self._id_counter, + exception_translator=self._crt_exception_translator, + ) components = { 'meta': CRTTransferMeta(self._id_counter, call_args), 'coordinator': coordinator, @@ -373,6 +417,9 @@ def serialize_http_request(self, transfer_type, future): """ raise NotImplementedError('serialize_http_request()') + def translate_crt_exception(self, exception): + raise NotImplementedError('translate_crt_exception()') + class BotocoreCRTRequestSerializer(BaseCRTRequestSerializer): def __init__(self, session, client_kwargs=None): @@ -428,19 +475,12 @@ def _crt_request_from_aws_request(self, aws_request): headers_list.append((name, str(value, 'utf-8'))) crt_headers = awscrt.http.HttpHeaders(headers_list) - # CRT requires body (if it exists) to be an I/O stream. - crt_body_stream = None - if aws_request.body: - if hasattr(aws_request.body, 'seek'): - crt_body_stream = aws_request.body - else: - crt_body_stream = BytesIO(aws_request.body) crt_request = awscrt.http.HttpRequest( method=aws_request.method, path=crt_path, headers=crt_headers, - body_stream=crt_body_stream, + body_stream=aws_request.body, ) return crt_request @@ -453,6 +493,25 @@ def _convert_to_crt_http_request(self, botocore_http_request): crt_request.headers.set("host", url_parts.netloc) if crt_request.headers.get('Content-MD5') is not None: crt_request.headers.remove("Content-MD5") + + # In general, the CRT S3 client expects a content length header. It + # only expects a missing content length header if the body is not + # seekable. However, botocore does not set the content length header + # for GetObject API requests and so we set the content length to zero + # to meet the CRT S3 client's expectation that the content length + # header is set even if there is no body. + if crt_request.headers.get('Content-Length') is None: + if botocore_http_request.body is None: + crt_request.headers.add('Content-Length', "0") + + # Botocore sets the Transfer-Encoding header when it cannot determine + # the content length of the request body (e.g. it's not seekable). + # However, CRT does not support this header, but it supports + # non-seekable bodies. So we remove this header to not cause issues + # in the downstream CRT S3 request. + if crt_request.headers.get('Transfer-Encoding') is not None: + crt_request.headers.remove('Transfer-Encoding') + return crt_request def _capture_http_request(self, request, **kwargs): @@ -484,6 +543,40 @@ def serialize_http_request(self, transfer_type, future): crt_request = self._convert_to_crt_http_request(botocore_http_request) return crt_request + def translate_crt_exception(self, exception): + if isinstance(exception, S3ResponseError): + return self._translate_crt_s3_response_error(exception) + else: + return None + + def _translate_crt_s3_response_error(self, s3_response_error): + status_code = s3_response_error.status_code + if status_code < 301: + # Botocore's exception parsing only + # runs on status codes >= 301 + return None + + headers = {k: v for k, v in s3_response_error.headers} + operation_name = s3_response_error.operation_name + if operation_name is not None: + service_model = self._client.meta.service_model + shape = service_model.operation_model(operation_name).output_shape + else: + shape = None + + response_dict = { + 'headers': botocore.awsrequest.HeadersDict(headers), + 'status_code': status_code, + 'body': s3_response_error.body, + } + parsed_response = self._client._response_parser.parse( + response_dict, shape=shape + ) + + error_code = parsed_response.get("Error", {}).get("Code") + error_class = self._client.exceptions.from_code(error_code) + return error_class(parsed_response, operation_name=operation_name) + class FakeRawResponse(BytesIO): def stream(self, amt=1024, decode_content=None): @@ -494,11 +587,33 @@ def stream(self, amt=1024, decode_content=None): yield chunk +class BotocoreCRTCredentialsWrapper: + def __init__(self, resolved_botocore_credentials): + self._resolved_credentials = resolved_botocore_credentials + + def __call__(self): + credentials = self._get_credentials().get_frozen_credentials() + return AwsCredentials( + credentials.access_key, credentials.secret_key, credentials.token + ) + + def to_crt_credentials_provider(self): + return AwsCredentialsProvider.new_delegate(self) + + def _get_credentials(self): + if self._resolved_credentials is None: + raise NoCredentialsError() + return self._resolved_credentials + + class CRTTransferCoordinator: """A helper class for managing CRTTransferFuture""" - def __init__(self, transfer_id=None, s3_request=None): + def __init__( + self, transfer_id=None, s3_request=None, exception_translator=None + ): self.transfer_id = transfer_id + self._exception_translator = exception_translator self._s3_request = s3_request self._lock = threading.Lock() self._exception = None @@ -531,11 +646,28 @@ def result(self, timeout=None): self._crt_future.result(timeout) except KeyboardInterrupt: self.cancel() + self._crt_future.result(timeout) raise + except Exception as e: + self.handle_exception(e) finally: if self._s3_request: self._s3_request = None - self._crt_future.result(timeout) + + def handle_exception(self, exc): + translated_exc = None + if self._exception_translator: + try: + translated_exc = self._exception_translator(exc) + except Exception as e: + # Bail out if we hit an issue translating + # and raise the original error. + logger.debug("Unable to translate exception.", exc_info=e) + pass + if translated_exc is not None: + raise translated_exc from exc + else: + raise exc def done(self): if self._crt_future is None: @@ -555,39 +687,20 @@ def __init__(self, crt_request_serializer, os_utils): def get_make_request_args( self, request_type, call_args, coordinator, future, on_done_after_calls ): - recv_filepath = None - send_filepath = None - s3_meta_request_type = getattr( - S3RequestType, request_type.upper(), S3RequestType.DEFAULT + request_args_handler = getattr( + self, + f'_get_make_request_args_{request_type}', + self._default_get_make_request_args, ) - on_done_before_calls = [] - if s3_meta_request_type == S3RequestType.GET_OBJECT: - final_filepath = call_args.fileobj - recv_filepath = self._os_utils.get_temp_filename(final_filepath) - file_ondone_call = RenameTempFileHandler( - coordinator, final_filepath, recv_filepath, self._os_utils - ) - on_done_before_calls.append(file_ondone_call) - elif s3_meta_request_type == S3RequestType.PUT_OBJECT: - send_filepath = call_args.fileobj - data_len = self._os_utils.get_file_size(send_filepath) - call_args.extra_args["ContentLength"] = data_len - - crt_request = self._request_serializer.serialize_http_request( - request_type, future + return request_args_handler( + request_type=request_type, + call_args=call_args, + coordinator=coordinator, + future=future, + on_done_before_calls=[], + on_done_after_calls=on_done_after_calls, ) - return { - 'request': crt_request, - 'type': s3_meta_request_type, - 'recv_filepath': recv_filepath, - 'send_filepath': send_filepath, - 'on_done': self.get_crt_callback( - future, 'done', on_done_before_calls, on_done_after_calls - ), - 'on_progress': self.get_crt_callback(future, 'progress'), - } - def get_crt_callback( self, future, @@ -613,6 +726,106 @@ def invoke_all_callbacks(*args, **kwargs): return invoke_all_callbacks + def _get_make_request_args_put_object( + self, + request_type, + call_args, + coordinator, + future, + on_done_before_calls, + on_done_after_calls, + ): + send_filepath = None + if isinstance(call_args.fileobj, str): + send_filepath = call_args.fileobj + data_len = self._os_utils.get_file_size(send_filepath) + call_args.extra_args["ContentLength"] = data_len + else: + call_args.extra_args["Body"] = call_args.fileobj + + checksum_algorithm = call_args.extra_args.pop( + 'ChecksumAlgorithm', 'CRC32' + ).upper() + checksum_config = awscrt.s3.S3ChecksumConfig( + algorithm=awscrt.s3.S3ChecksumAlgorithm[checksum_algorithm], + location=awscrt.s3.S3ChecksumLocation.TRAILER, + ) + # Suppress botocore's automatic MD5 calculation by setting an override + # value that will get deleted in the BotocoreCRTRequestSerializer. + # As part of the CRT S3 request, we request the CRT S3 client to + # automatically add trailing checksums to its uploads. + call_args.extra_args["ContentMD5"] = "override-to-be-removed" + + make_request_args = self._default_get_make_request_args( + request_type=request_type, + call_args=call_args, + coordinator=coordinator, + future=future, + on_done_before_calls=on_done_before_calls, + on_done_after_calls=on_done_after_calls, + ) + make_request_args['send_filepath'] = send_filepath + make_request_args['checksum_config'] = checksum_config + return make_request_args + + def _get_make_request_args_get_object( + self, + request_type, + call_args, + coordinator, + future, + on_done_before_calls, + on_done_after_calls, + ): + recv_filepath = None + on_body = None + checksum_config = awscrt.s3.S3ChecksumConfig(validate_response=True) + if isinstance(call_args.fileobj, str): + final_filepath = call_args.fileobj + recv_filepath = self._os_utils.get_temp_filename(final_filepath) + on_done_before_calls.append( + RenameTempFileHandler( + coordinator, final_filepath, recv_filepath, self._os_utils + ) + ) + else: + on_body = OnBodyFileObjWriter(call_args.fileobj) + + make_request_args = self._default_get_make_request_args( + request_type=request_type, + call_args=call_args, + coordinator=coordinator, + future=future, + on_done_before_calls=on_done_before_calls, + on_done_after_calls=on_done_after_calls, + ) + make_request_args['recv_filepath'] = recv_filepath + make_request_args['on_body'] = on_body + make_request_args['checksum_config'] = checksum_config + return make_request_args + + def _default_get_make_request_args( + self, + request_type, + call_args, + coordinator, + future, + on_done_before_calls, + on_done_after_calls, + ): + return { + 'request': self._request_serializer.serialize_http_request( + request_type, future + ), + 'type': getattr( + S3RequestType, request_type.upper(), S3RequestType.DEFAULT + ), + 'on_done': self.get_crt_callback( + future, 'done', on_done_before_calls, on_done_after_calls + ), + 'on_progress': self.get_crt_callback(future, 'progress'), + } + class RenameTempFileHandler: def __init__(self, coordinator, final_filename, temp_filename, osutil): @@ -642,3 +855,11 @@ def __init__(self, coordinator): def __call__(self, **kwargs): self._coordinator.set_done_callbacks_complete() + + +class OnBodyFileObjWriter: + def __init__(self, fileobj): + self._fileobj = fileobj + + def __call__(self, chunk, **kwargs): + self._fileobj.write(chunk) diff --git a/setup.cfg b/setup.cfg index fd892717..2e9978dc 100644 --- a/setup.cfg +++ b/setup.cfg @@ -3,10 +3,10 @@ universal = 0 [metadata] requires_dist = - botocore>=1.12.36,<2.0a.0 + botocore>=1.32.6,<2.0a.0 [options.extras_require] -crt = botocore[crt]>=1.20.29,<2.0a0 +crt = botocore[crt]>=1.32.6,<2.0a0 [flake8] ignore = E203,E226,E501,W503,W504 diff --git a/setup.py b/setup.py index c194bdbe..6c1432f0 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ requires = [ - 'botocore>=1.12.36,<2.0a.0', + 'botocore>=1.32.6,<2.0a.0', ] @@ -30,7 +30,7 @@ def get_version(): include_package_data=True, install_requires=requires, extras_require={ - 'crt': 'botocore[crt]>=1.20.29,<2.0a.0', + 'crt': 'botocore[crt]>=1.32.6,<2.0a.0', }, license="Apache License 2.0", python_requires=">= 3.7", diff --git a/tests/__init__.py b/tests/__init__.py index e36c4936..03590fef 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -509,6 +509,9 @@ def write(self, b): def read(self, n=-1): return self._data.read(n) + def readinto(self, b): + return self._data.readinto(b) + class NonSeekableWriter(io.RawIOBase): def __init__(self, fileobj): diff --git a/tests/functional/test_crt.py b/tests/functional/test_crt.py index 0ead2959..c56ea301 100644 --- a/tests/functional/test_crt.py +++ b/tests/functional/test_crt.py @@ -11,6 +11,7 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import fnmatch +import io import threading import time from concurrent.futures import Future @@ -18,7 +19,15 @@ from botocore.session import Session from s3transfer.subscribers import BaseSubscriber -from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest +from tests import ( + HAS_CRT, + FileCreator, + NonSeekableReader, + NonSeekableWriter, + mock, + requires_crt, + unittest, +) if HAS_CRT: import awscrt @@ -60,13 +69,19 @@ def setUp(self): self.region = 'us-west-2' self.bucket = "test_bucket" self.key = "test_key" + self.expected_content = b'my content' + self.expected_download_content = b'new content' self.files = FileCreator() - self.filename = self.files.create_file('myfile', 'my content') + self.filename = self.files.create_file( + 'myfile', self.expected_content, mode='wb' + ) self.expected_path = "/" + self.bucket + "/" + self.key self.expected_host = "s3.%s.amazonaws.com" % (self.region) self.s3_request = mock.Mock(awscrt.s3.S3Request) self.s3_crt_client = mock.Mock(awscrt.s3.S3Client) - self.s3_crt_client.make_request.return_value = self.s3_request + self.s3_crt_client.make_request.side_effect = ( + self._simulate_make_request_side_effect + ) self.session = Session() self.session.set_config_variable('region', self.region) self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( @@ -81,6 +96,44 @@ def setUp(self): def tearDown(self): self.files.remove_all() + def _assert_expected_crt_http_request( + self, + crt_http_request, + expected_http_method='GET', + expected_host=None, + expected_path=None, + expected_body_content=None, + expected_content_length=None, + expected_missing_headers=None, + ): + if expected_host is None: + expected_host = self.expected_host + if expected_path is None: + expected_path = self.expected_path + self.assertEqual(crt_http_request.method, expected_http_method) + self.assertEqual(crt_http_request.headers.get("host"), expected_host) + self.assertEqual(crt_http_request.path, expected_path) + if expected_body_content is not None: + # Note: The underlying CRT awscrt.io.InputStream does not expose + # a public read method so we have to reach into the private, + # underlying stream to determine the content. We should update + # to use a public interface if a public interface is ever exposed. + self.assertEqual( + crt_http_request.body_stream._stream.read(), + expected_body_content, + ) + if expected_content_length is not None: + self.assertEqual( + crt_http_request.headers.get('Content-Length'), + str(expected_content_length), + ) + if expected_missing_headers is not None: + header_names = [ + header[0].lower() for header in crt_http_request.headers + ] + for expected_missing_header in expected_missing_headers: + self.assertNotIn(expected_missing_header.lower(), header_names) + def _assert_subscribers_called(self, expected_future=None): self.assertTrue(self.record_subscriber.on_queued_called) self.assertTrue(self.record_subscriber.on_done_called) @@ -92,6 +145,21 @@ def _assert_subscribers_called(self, expected_future=None): self.record_subscriber.on_done_future, expected_future ) + def _get_expected_upload_checksum_config(self, **overrides): + checksum_config_kwargs = { + 'algorithm': awscrt.s3.S3ChecksumAlgorithm.CRC32, + 'location': awscrt.s3.S3ChecksumLocation.TRAILER, + } + checksum_config_kwargs.update(overrides) + return awscrt.s3.S3ChecksumConfig(**checksum_config_kwargs) + + def _get_expected_download_checksum_config(self, **overrides): + checksum_config_kwargs = { + 'validate_response': True, + } + checksum_config_kwargs.update(overrides) + return awscrt.s3.S3ChecksumConfig(**checksum_config_kwargs) + def _invoke_done_callbacks(self, **kwargs): callargs = self.s3_crt_client.make_request.call_args callargs_kwargs = callargs[1] @@ -99,47 +167,213 @@ def _invoke_done_callbacks(self, **kwargs): on_done(error=None) def _simulate_file_download(self, recv_filepath): - self.files.create_file(recv_filepath, "fake response") + self.files.create_file( + recv_filepath, self.expected_download_content, mode='wb' + ) + + def _simulate_on_body_download(self, on_body_callback): + on_body_callback(chunk=self.expected_download_content, offset=0) def _simulate_make_request_side_effect(self, **kwargs): if kwargs.get('recv_filepath'): self._simulate_file_download(kwargs['recv_filepath']) + if kwargs.get('on_body'): + self._simulate_on_body_download(kwargs['on_body']) self._invoke_done_callbacks() - return mock.DEFAULT + return self.s3_request def test_upload(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect - ) future = self.transfer_manager.upload( self.filename, self.bucket, self.key, {}, [self.record_subscriber] ) future.result() - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - self.assertEqual(callargs_kwargs["send_filepath"], self.filename) - self.assertIsNone(callargs_kwargs["recv_filepath"]) + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.PUT_OBJECT + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.PUT_OBJECT, + 'send_filepath': self.filename, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'checksum_config': self._get_expected_upload_checksum_config(), + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='PUT', + expected_content_length=len(self.expected_content), + expected_missing_headers=['Content-MD5'], ) - crt_request = callargs_kwargs["request"] - self.assertEqual("PUT", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) self._assert_subscribers_called(future) - def test_download(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect + def test_upload_from_seekable_stream(self): + with open(self.filename, 'rb') as f: + future = self.transfer_manager.upload( + f, self.bucket, self.key, {}, [self.record_subscriber] + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.PUT_OBJECT, + 'send_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'checksum_config': self._get_expected_upload_checksum_config(), + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='PUT', + expected_body_content=self.expected_content, + expected_content_length=len(self.expected_content), + expected_missing_headers=['Content-MD5'], + ) + self._assert_subscribers_called(future) + + def test_upload_from_nonseekable_stream(self): + nonseekable_stream = NonSeekableReader(self.expected_content) + future = self.transfer_manager.upload( + nonseekable_stream, + self.bucket, + self.key, + {}, + [self.record_subscriber], + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.PUT_OBJECT, + 'send_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'checksum_config': self._get_expected_upload_checksum_config(), + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='PUT', + expected_body_content=self.expected_content, + expected_missing_headers=[ + 'Content-MD5', + 'Content-Length', + 'Transfer-Encoding', + ], + ) + self._assert_subscribers_called(future) + + def test_upload_override_checksum_algorithm(self): + future = self.transfer_manager.upload( + self.filename, + self.bucket, + self.key, + {'ChecksumAlgorithm': 'CRC32C'}, + [self.record_subscriber], + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.PUT_OBJECT, + 'send_filepath': self.filename, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'checksum_config': self._get_expected_upload_checksum_config( + algorithm=awscrt.s3.S3ChecksumAlgorithm.CRC32C + ), + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='PUT', + expected_content_length=len(self.expected_content), + expected_missing_headers=[ + 'Content-MD5', + 'x-amz-sdk-checksum-algorithm', + 'X-Amz-Trailer', + ], + ) + self._assert_subscribers_called(future) + + def test_upload_override_checksum_algorithm_accepts_lowercase(self): + future = self.transfer_manager.upload( + self.filename, + self.bucket, + self.key, + {'ChecksumAlgorithm': 'crc32c'}, + [self.record_subscriber], + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.PUT_OBJECT, + 'send_filepath': self.filename, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'checksum_config': self._get_expected_upload_checksum_config( + algorithm=awscrt.s3.S3ChecksumAlgorithm.CRC32C + ), + }, ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='PUT', + expected_content_length=len(self.expected_content), + expected_missing_headers=[ + 'Content-MD5', + 'x-amz-sdk-checksum-algorithm', + 'X-Amz-Trailer', + ], + ) + self._assert_subscribers_called(future) + + def test_upload_throws_error_for_unsupported_checksum(self): + with self.assertRaisesRegex( + ValueError, 'ChecksumAlgorithm: UNSUPPORTED not supported' + ): + self.transfer_manager.upload( + self.filename, + self.bucket, + self.key, + {'ChecksumAlgorithm': 'UNSUPPORTED'}, + [self.record_subscriber], + ) + + def test_download(self): future = self.transfer_manager.download( self.bucket, self.key, self.filename, {}, [self.record_subscriber] ) future.result() - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.GET_OBJECT, + 'recv_filepath': mock.ANY, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'on_body': None, + 'checksum_config': self._get_expected_download_checksum_config(), + }, + ) # the recv_filepath will be set to a temporary file path with some # random suffix self.assertTrue( @@ -148,42 +382,111 @@ def test_download(self): f'{self.filename}.*', ) ) - self.assertIsNone(callargs_kwargs["send_filepath"]) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='GET', + expected_content_length=0, + ) + self._assert_subscribers_called(future) + with open(self.filename, 'rb') as f: + # Check the fake response overwrites the file because of download + self.assertEqual(f.read(), self.expected_download_content) + + def test_download_to_seekable_stream(self): + with open(self.filename, 'wb') as f: + future = self.transfer_manager.download( + self.bucket, self.key, f, {}, [self.record_subscriber] + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.GET_OBJECT + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.GET_OBJECT, + 'recv_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'on_body': mock.ANY, + 'checksum_config': self._get_expected_download_checksum_config(), + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='GET', + expected_content_length=0, ) - crt_request = callargs_kwargs["request"] - self.assertEqual("GET", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) self._assert_subscribers_called(future) with open(self.filename, 'rb') as f: # Check the fake response overwrites the file because of download - self.assertEqual(f.read(), b'fake response') + self.assertEqual(f.read(), self.expected_download_content) - def test_delete(self): - self.s3_crt_client.make_request.side_effect = ( - self._simulate_make_request_side_effect + def test_download_to_nonseekable_stream(self): + underlying_stream = io.BytesIO() + nonseekable_stream = NonSeekableWriter(underlying_stream) + future = self.transfer_manager.download( + self.bucket, + self.key, + nonseekable_stream, + {}, + [self.record_subscriber], + ) + future.result() + + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] + self.assertEqual( + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.GET_OBJECT, + 'recv_filepath': None, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + 'on_body': mock.ANY, + 'checksum_config': self._get_expected_download_checksum_config(), + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='GET', + expected_content_length=0, + ) + self._assert_subscribers_called(future) + self.assertEqual( + underlying_stream.getvalue(), self.expected_download_content ) + + def test_delete(self): future = self.transfer_manager.delete( self.bucket, self.key, {}, [self.record_subscriber] ) future.result() - callargs = self.s3_crt_client.make_request.call_args - callargs_kwargs = callargs[1] - self.assertIsNone(callargs_kwargs["send_filepath"]) - self.assertIsNone(callargs_kwargs["recv_filepath"]) + callargs_kwargs = self.s3_crt_client.make_request.call_args[1] self.assertEqual( - callargs_kwargs["type"], awscrt.s3.S3RequestType.DEFAULT + callargs_kwargs, + { + 'request': mock.ANY, + 'type': awscrt.s3.S3RequestType.DEFAULT, + 'on_progress': mock.ANY, + 'on_done': mock.ANY, + }, + ) + self._assert_expected_crt_http_request( + callargs_kwargs["request"], + expected_http_method='DELETE', + expected_content_length=0, ) - crt_request = callargs_kwargs["request"] - self.assertEqual("DELETE", crt_request.method) - self.assertEqual(self.expected_path, crt_request.path) - self.assertEqual(self.expected_host, crt_request.headers.get("host")) self._assert_subscribers_called(future) def test_blocks_when_max_requests_processes_reached(self): + self.s3_crt_client.make_request.return_value = self.s3_request + # We simulate blocking by not invoking the on_done callbacks for + # all of the requests we send. The default side effect invokes all + # callbacks so we need to unset the side effect to avoid on_done from + # being called in the child threads. + self.s3_crt_client.make_request.side_effect = None futures = [] callargs = (self.bucket, self.key, self.filename, {}, []) max_request_processes = 128 # the hard coded max processes diff --git a/tests/integration/test_crt.py b/tests/integration/test_crt.py index 157ae2dc..b3fa7e0f 100644 --- a/tests/integration/test_crt.py +++ b/tests/integration/test_crt.py @@ -11,11 +11,18 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. import glob +import io import os from s3transfer.subscribers import BaseSubscriber from s3transfer.utils import OSUtils -from tests import HAS_CRT, assert_files_equal, requires_crt +from tests import ( + HAS_CRT, + NonSeekableReader, + NonSeekableWriter, + assert_files_equal, + requires_crt, +) from tests.integration import BaseTransferManagerIntegTest if HAS_CRT: @@ -44,13 +51,17 @@ def on_done(self, **kwargs): class TestCRTS3Transfers(BaseTransferManagerIntegTest): """Tests for the high level s3transfer based on CRT implementation.""" + def setUp(self): + super().setUp() + self.s3_key = 's3key.txt' + self.download_path = os.path.join(self.files.rootdir, 'download.txt') + def _create_s3_transfer(self): self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( - self.session + self.session, client_kwargs={'region_name': self.region} ) - credetial_resolver = self.session.get_component('credential_provider') self.s3_crt_client = s3transfer.crt.create_s3_crt_client( - self.session.get_config_variable("region"), credetial_resolver + self.region, self._get_crt_credentials_provider() ) self.record_subscriber = RecordingSubscriber() self.osutil = OSUtils() @@ -58,6 +69,47 @@ def _create_s3_transfer(self): self.s3_crt_client, self.request_serializer ) + def _get_crt_credentials_provider(self): + botocore_credentials = self.session.get_credentials() + wrapper = s3transfer.crt.BotocoreCRTCredentialsWrapper( + botocore_credentials + ) + return wrapper.to_crt_credentials_provider() + + def _upload_with_crt_transfer_manager(self, fileobj, key=None): + if key is None: + key = self.s3_key + self.addCleanup(self.delete_object, key) + with self._create_s3_transfer() as transfer: + future = transfer.upload( + fileobj, + self.bucket_name, + key, + subscribers=[self.record_subscriber], + ) + future.result() + + def _download_with_crt_transfer_manager(self, fileobj, key=None): + if key is None: + key = self.s3_key + self.addCleanup(self.delete_object, key) + with self._create_s3_transfer() as transfer: + future = transfer.download( + self.bucket_name, + key, + fileobj, + subscribers=[self.record_subscriber], + ) + future.result() + + def _assert_expected_s3_object(self, key, expected_size=None): + self.assertTrue(self.object_exists(key)) + if expected_size is not None: + response = self.client.head_object( + Bucket=self.bucket_name, Key=key + ) + self.assertEqual(response['ContentLength'], expected_size) + def _assert_has_public_read_acl(self, response): grants = response['Grants'] public_read = [ @@ -176,6 +228,43 @@ def test_upload_file_above_threshold_with_ssec(self): self.assertEqual(response['SSECustomerAlgorithm'], 'AES256') self._assert_subscribers_called(file_size) + def test_upload_seekable_stream(self): + size = 1024 * 1024 + self._upload_with_crt_transfer_manager(io.BytesIO(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_multipart_upload_seekable_stream(self): + size = 20 * 1024 * 1024 + self._upload_with_crt_transfer_manager(io.BytesIO(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_upload_nonseekable_stream(self): + size = 1024 * 1024 + self._upload_with_crt_transfer_manager(NonSeekableReader(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_multipart_upload_nonseekable_stream(self): + size = 20 * 1024 * 1024 + self._upload_with_crt_transfer_manager(NonSeekableReader(b'0' * size)) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_upload_empty_file(self): + size = 0 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self._upload_with_crt_transfer_manager(filename) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + + def test_upload_empty_stream(self): + size = 0 + self._upload_with_crt_transfer_manager(io.BytesIO(b'')) + self._assert_expected_s3_object(self.s3_key, expected_size=size) + self._assert_subscribers_called(size) + def test_can_send_extra_params_on_download(self): # We're picking the customer provided sse feature # of S3 to test the extra_args functionality of @@ -244,6 +333,65 @@ def test_download_above_threshold(self): file_size = self.osutil.get_file_size(download_path) self._assert_subscribers_called(file_size) + def test_download_seekable_stream(self): + size = 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(f) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_multipart_download_seekable_stream(self): + size = 20 * 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(f) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_download_nonseekable_stream(self): + size = 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(NonSeekableWriter(f)) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_multipart_download_nonseekable_stream(self): + size = 20 * 1024 * 1024 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(NonSeekableWriter(f)) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_download_empty_file(self): + size = 0 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + self._download_with_crt_transfer_manager(self.download_path) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + + def test_download_empty_stream(self): + size = 0 + filename = self.files.create_file_with_size(self.s3_key, filesize=size) + self.upload_file(filename, self.s3_key) + + with open(self.download_path, 'wb') as f: + self._download_with_crt_transfer_manager(f) + self._assert_subscribers_called(size) + assert_files_equal(filename, self.download_path) + def test_delete(self): transfer = self._create_s3_transfer() filename = self.files.create_file_with_size( @@ -363,3 +511,20 @@ def test_download_cancel(self): possible_matches = glob.glob('%s*' % download_path) self.assertEqual(possible_matches, []) self._assert_subscribers_called() + + def test_exception_translation(self): + # Test that CRT's S3ResponseError translates to botocore error + transfer = self._create_s3_transfer() + download_path = os.path.join( + self.files.rootdir, 'obviously-no-such-key.txt' + ) + with self.assertRaises(self.client.exceptions.NoSuchKey) as cm: + future = transfer.download( + self.bucket_name, + 'obviously-no-such-key.txt', + download_path, + subscribers=[self.record_subscriber], + ) + future.result() + + self.assertEqual(cm.exception.response['Error']['Code'], 'NoSuchKey') diff --git a/tests/unit/test_crt.py b/tests/unit/test_crt.py index b6ad3245..da899289 100644 --- a/tests/unit/test_crt.py +++ b/tests/unit/test_crt.py @@ -10,7 +10,11 @@ # distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License. -from botocore.credentials import CredentialResolver, ReadOnlyCredentials +import io + +import pytest +from botocore.credentials import Credentials, ReadOnlyCredentials +from botocore.exceptions import ClientError, NoCredentialsError from botocore.session import Session from s3transfer.exceptions import TransferNotDoneError @@ -18,15 +22,73 @@ from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest if HAS_CRT: + import awscrt.auth import awscrt.s3 import s3transfer.crt +requires_crt_pytest = pytest.mark.skipif( + not HAS_CRT, reason="Test requires awscrt to be installed." +) + + +@pytest.fixture +def mock_crt_process_lock(monkeypatch): + # The process lock is cached at the module layer whenever the + # cross process lock is successfully acquired. This patch ensures that + # test cases will start off with no previously cached process lock and + # if a cross process is instantiated/acquired it will be the mock that + # can be used for controlling lock behavior. + monkeypatch.setattr('s3transfer.crt.CRT_S3_PROCESS_LOCK', None) + with mock.patch('awscrt.s3.CrossProcessLock', spec=True) as mock_lock: + yield mock_lock + + +@pytest.fixture +def mock_s3_crt_client(): + with mock.patch('s3transfer.crt.S3Client', spec=True) as mock_client: + yield mock_client + + +@pytest.fixture +def mock_get_recommended_throughput_target_gbps(): + with mock.patch( + 's3transfer.crt.get_recommended_throughput_target_gbps' + ) as mock_get_target_gbps: + yield mock_get_target_gbps + + class CustomFutureException(Exception): pass +@requires_crt_pytest +class TestCRTProcessLock: + def test_acquire_crt_s3_process_lock(self, mock_crt_process_lock): + lock = s3transfer.crt.acquire_crt_s3_process_lock('app-name') + assert lock is s3transfer.crt.CRT_S3_PROCESS_LOCK + assert lock is mock_crt_process_lock.return_value + mock_crt_process_lock.assert_called_once_with('app-name') + mock_crt_process_lock.return_value.acquire.assert_called_once_with() + + def test_unable_to_acquire_lock_returns_none(self, mock_crt_process_lock): + mock_crt_process_lock.return_value.acquire.side_effect = RuntimeError + assert s3transfer.crt.acquire_crt_s3_process_lock('app-name') is None + assert s3transfer.crt.CRT_S3_PROCESS_LOCK is None + mock_crt_process_lock.assert_called_once_with('app-name') + mock_crt_process_lock.return_value.acquire.assert_called_once_with() + + def test_multiple_acquires_return_same_lock(self, mock_crt_process_lock): + lock = s3transfer.crt.acquire_crt_s3_process_lock('app-name') + assert s3transfer.crt.acquire_crt_s3_process_lock('app-name') is lock + assert lock is s3transfer.crt.CRT_S3_PROCESS_LOCK + + # The process lock should have only been instantiated and acquired once + mock_crt_process_lock.assert_called_once_with('app-name') + mock_crt_process_lock.return_value.acquire.assert_called_once_with() + + @requires_crt class TestBotocoreCRTRequestSerializer(unittest.TestCase): def setUp(self): @@ -102,46 +164,131 @@ def test_delete_request(self): self.assertEqual(self.expected_host, crt_request.headers.get("host")) self.assertIsNone(crt_request.headers.get("Authorization")) + def _create_crt_response_error( + self, status_code, body, operation_name=None + ): + return awscrt.s3.S3ResponseError( + code=14343, + name='AWS_ERROR_S3_INVALID_RESPONSE_STATUS', + message='Invalid response status from request', + status_code=status_code, + headers=[ + ('x-amz-request-id', 'QSJHJJZR2EDYD4GQ'), + ( + 'x-amz-id-2', + 'xDbgdKdvYZTjgpOTzm7yNP2JPrOQl+eaQvUkFdOjdJoWkIC643fgHxdsHpUKvVAfjKf5F6otEYA=', + ), + ('Content-Type', 'application/xml'), + ('Transfer-Encoding', 'chunked'), + ('Date', 'Fri, 10 Nov 2023 23:22:47 GMT'), + ('Server', 'AmazonS3'), + ], + body=body, + operation_name=operation_name, + ) + + def test_translate_get_object_404(self): + body = ( + b'\n' + b'NoSuchKey' + b'The specified key does not exist.' + b'obviously-no-such-key.txt' + b'SBJ7ZQY03N1WDW9T' + b'SomeHostId' + ) + crt_exc = self._create_crt_response_error(404, body, 'GetObject') + boto_err = self.request_serializer.translate_crt_exception(crt_exc) + self.assertIsInstance( + boto_err, self.session.create_client('s3').exceptions.NoSuchKey + ) -@requires_crt -class TestCRTCredentialProviderAdapter(unittest.TestCase): - def setUp(self): - self.botocore_credential_provider = mock.Mock(CredentialResolver) - self.access_key = "access_key" - self.secret_key = "secret_key" - self.token = "token" - self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials( - self.access_key, self.secret_key, self.token + def test_translate_head_object_404(self): + # There's no body in a HEAD response, so we can't map it to a modeled S3 exception. + # But it should still map to a botocore ClientError + body = None + crt_exc = self._create_crt_response_error( + 404, body, operation_name='HeadObject' ) + boto_err = self.request_serializer.translate_crt_exception(crt_exc) + self.assertIsInstance(boto_err, ClientError) - def _call_adapter_and_check(self, credentails_provider_adapter): - credentials = credentails_provider_adapter() - self.assertEqual(credentials.access_key_id, self.access_key) - self.assertEqual(credentials.secret_access_key, self.secret_key) - self.assertEqual(credentials.session_token, self.token) + def test_translate_unknown_operation_404(self): + body = None + crt_exc = self._create_crt_response_error(404, body) + boto_err = self.request_serializer.translate_crt_exception(crt_exc) + self.assertIsInstance(boto_err, ClientError) - def test_fetch_crt_credentials_successfully(self): - credentails_provider_adapter = ( - s3transfer.crt.CRTCredentialProviderAdapter( - self.botocore_credential_provider - ) + +@requires_crt_pytest +class TestBotocoreCRTCredentialsWrapper: + @pytest.fixture + def botocore_credentials(self): + return Credentials( + access_key='access_key', secret_key='secret_key', token='token' + ) + + def assert_crt_credentials( + self, + crt_credentials, + expected_access_key='access_key', + expected_secret_key='secret_key', + expected_token='token', + ): + assert crt_credentials.access_key_id == expected_access_key + assert crt_credentials.secret_access_key == expected_secret_key + assert crt_credentials.session_token == expected_token + + def test_fetch_crt_credentials_successfully(self, botocore_credentials): + wrapper = s3transfer.crt.BotocoreCRTCredentialsWrapper( + botocore_credentials + ) + crt_credentials = wrapper() + self.assert_crt_credentials(crt_credentials) + + def test_wrapper_does_not_cache_frozen_credentials(self): + mock_credentials = mock.Mock(Credentials) + mock_credentials.get_frozen_credentials.side_effect = [ + ReadOnlyCredentials('access_key_1', 'secret_key_1', 'token_1'), + ReadOnlyCredentials('access_key_2', 'secret_key_2', 'token_2'), + ] + wrapper = s3transfer.crt.BotocoreCRTCredentialsWrapper( + mock_credentials + ) + + crt_credentials_1 = wrapper() + self.assert_crt_credentials( + crt_credentials_1, + expected_access_key='access_key_1', + expected_secret_key='secret_key_1', + expected_token='token_1', + ) + + crt_credentials_2 = wrapper() + self.assert_crt_credentials( + crt_credentials_2, + expected_access_key='access_key_2', + expected_secret_key='secret_key_2', + expected_token='token_2', ) - self._call_adapter_and_check(credentails_provider_adapter) - def test_load_credentials_once(self): - credentails_provider_adapter = ( - s3transfer.crt.CRTCredentialProviderAdapter( - self.botocore_credential_provider - ) + assert mock_credentials.get_frozen_credentials.call_count == 2 + + def test_raises_error_when_resolved_credentials_is_none(self): + wrapper = s3transfer.crt.BotocoreCRTCredentialsWrapper(None) + with pytest.raises(NoCredentialsError): + wrapper() + + def test_to_crt_credentials_provider(self, botocore_credentials): + wrapper = s3transfer.crt.BotocoreCRTCredentialsWrapper( + botocore_credentials ) - called_times = 5 - for i in range(called_times): - self._call_adapter_and_check(credentails_provider_adapter) - # Assert that the load_credentails of botocore credential provider - # will only be called once - self.assertEqual( - self.botocore_credential_provider.load_credentials.call_count, 1 + crt_credentials_provider = wrapper.to_crt_credentials_provider() + assert isinstance( + crt_credentials_provider, awscrt.auth.AwsCredentialsProvider ) + get_credentials_future = crt_credentials_provider.get_credentials() + crt_credentials = get_credentials_future.result() + self.assert_crt_credentials(crt_credentials) @requires_crt @@ -171,3 +318,47 @@ def test_set_exception_can_override_previous_exception(self): self.future.set_exception(CustomFutureException()) with self.assertRaises(CustomFutureException): self.future.result() + + +@requires_crt +class TestOnBodyFileObjWriter(unittest.TestCase): + def test_call(self): + fileobj = io.BytesIO() + writer = s3transfer.crt.OnBodyFileObjWriter(fileobj) + writer(chunk=b'content') + self.assertEqual(fileobj.getvalue(), b'content') + + +@requires_crt_pytest +class TestCreateS3CRTClient: + @pytest.mark.parametrize( + 'provided_bytes_per_sec,recommended_gbps,expected_gbps', + [ + (None, 100.0, 100.0), + (None, None, 10.0), + # NOTE: create_s3_crt_client() accepts target throughput as bytes + # per second and it is converted to gigabits per second for the + # CRT client instantiation. + (1_000_000_000, None, 8.0), + (1_000_000_000, 100.0, 8.0), + ], + ) + def test_target_throughput( + self, + provided_bytes_per_sec, + recommended_gbps, + expected_gbps, + mock_s3_crt_client, + mock_get_recommended_throughput_target_gbps, + ): + mock_get_recommended_throughput_target_gbps.return_value = ( + recommended_gbps + ) + s3transfer.crt.create_s3_crt_client( + 'us-west-2', + target_throughput=provided_bytes_per_sec, + ) + assert ( + mock_s3_crt_client.call_args[1]['throughput_target_gbps'] + == expected_gbps + )