diff --git a/awscli/customizations/s3/factory.py b/awscli/customizations/s3/factory.py index bb38f4650a63..3d9b3c9282fc 100644 --- a/awscli/customizations/s3/factory.py +++ b/awscli/customizations/s3/factory.py @@ -18,7 +18,8 @@ from s3transfer.manager import TransferManager from s3transfer.crt import ( acquire_crt_s3_process_lock, create_s3_crt_client, - BotocoreCRTRequestSerializer, CRTTransferManager + BotocoreCRTRequestSerializer, CRTTransferManager, + CRTCredentialsWrapper ) from awscli.compat import urlparse @@ -131,9 +132,9 @@ def _create_crt_client(self, params, runtime_config): if multipart_chunksize: create_crt_client_kwargs['part_size'] = multipart_chunksize if params.get('sign_request', True): + crt_credentials_provider = self._get_crt_credentials_provider() create_crt_client_kwargs[ - 'botocore_credential_provider'] = self._session.get_component( - 'credential_provider') + 'crt_credentials_provider'] = crt_credentials_provider return create_s3_crt_client(**create_crt_client_kwargs) @@ -163,6 +164,11 @@ def _create_classic_transfer_manager(self, params, runtime_config, ) return TransferManager(client, transfer_config) + def _get_crt_credentials_provider(self): + botocore_credentials = self._session.get_credentials() + wrapper = CRTCredentialsWrapper(botocore_credentials) + return wrapper.to_crt_credentials_provider() + def _resolve_region(self, params): region = params.get('region') if region is None: diff --git a/awscli/s3transfer/crt.py b/awscli/s3transfer/crt.py index 465d58d38895..f2b579993dbf 100644 --- a/awscli/s3transfer/crt.py +++ b/awscli/s3transfer/crt.py @@ -71,33 +71,9 @@ def acquire_crt_s3_process_lock(name): return CRT_S3_PROCESS_LOCK -class CRTCredentialProviderAdapter: - def __init__(self, botocore_credential_provider): - self._botocore_credential_provider = botocore_credential_provider - self._loaded_credentials = None - self._lock = threading.Lock() - - def __call__(self): - credentials = self._get_credentials().get_frozen_credentials() - return AwsCredentials( - credentials.access_key, credentials.secret_key, credentials.token - ) - - def _get_credentials(self): - with self._lock: - if self._loaded_credentials is None: - loaded_creds = ( - self._botocore_credential_provider.load_credentials() - ) - if loaded_creds is None: - raise NoCredentialsError() - self._loaded_credentials = loaded_creds - return self._loaded_credentials - - def create_s3_crt_client( region, - botocore_credential_provider=None, + crt_credentials_provider=None, num_threads=None, target_throughput=None, part_size=8 * MB, @@ -108,10 +84,10 @@ def create_s3_crt_client( :type region: str :param region: The region used for signing - :type botocore_credential_provider: - Optional[botocore.credentials.CredentialResolver] - :param botocore_credential_provider: Provide credentials for CRT - to sign the request if not set, the request will not be signed + :type crt_credentials_provider: + Optional[awscrt.auth.AwsCredentialsProvider] + :param crt_credentials_provider: CRT AWS credentials provider + to use to sign requests. If not set, requests will not be signed. :type num_threads: Optional[int] :param num_threads: Number of worker threads generated. Default @@ -151,7 +127,6 @@ def create_s3_crt_client( event_loop_group = EventLoopGroup(num_threads) host_resolver = DefaultHostResolver(event_loop_group) bootstrap = ClientBootstrap(event_loop_group, host_resolver) - provider = None tls_connection_options = None tls_mode = ( @@ -167,20 +142,13 @@ def create_s3_crt_client( tls_ctx_options.verify_peer = False client_tls_option = ClientTlsContext(tls_ctx_options) tls_connection_options = client_tls_option.new_connection_options() - if botocore_credential_provider: - credentails_provider_adapter = CRTCredentialProviderAdapter( - botocore_credential_provider - ) - provider = AwsCredentialsProvider.new_delegate( - credentails_provider_adapter - ) target_gbps = _get_crt_throughput_target_gbps( provided_throughput_target_bytes=target_throughput ) return S3Client( bootstrap=bootstrap, region=region, - credential_provider=provider, + credential_provider=crt_credentials_provider, part_size=part_size, tls_mode=tls_mode, tls_connection_options=tls_connection_options, @@ -575,6 +543,25 @@ def stream(self, amt=1024, decode_content=None): yield chunk +class CRTCredentialsWrapper: + def __init__(self, resolved_botocore_credentials): + self._resolved_credentials = resolved_botocore_credentials + + def __call__(self): + credentials = self._get_credentials().get_frozen_credentials() + return AwsCredentials( + credentials.access_key, credentials.secret_key, credentials.token + ) + + def to_crt_credentials_provider(self): + return AwsCredentialsProvider.new_delegate(self) + + def _get_credentials(self): + if self._resolved_credentials is None: + raise NoCredentialsError() + return self._resolved_credentials + + class CRTTransferCoordinator: """A helper class for managing CRTTransferFuture""" diff --git a/tests/integration/s3transfer/test_crt.py b/tests/integration/s3transfer/test_crt.py index 9580aecf50f8..a7ca2b5a3ddd 100644 --- a/tests/integration/s3transfer/test_crt.py +++ b/tests/integration/s3transfer/test_crt.py @@ -60,9 +60,8 @@ def _create_s3_transfer(self): self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( self.session, client_kwargs={'region_name': self.region} ) - credetial_resolver = self.session.get_component('credential_provider') self.s3_crt_client = s3transfer.crt.create_s3_crt_client( - self.region, credetial_resolver + self.region, self._get_crt_credentials_provider() ) self.record_subscriber = RecordingSubscriber() self.osutil = OSUtils() @@ -70,6 +69,11 @@ def _create_s3_transfer(self): self.s3_crt_client, self.request_serializer ) + def _get_crt_credentials_provider(self): + botocore_credentials = self.session.get_credentials() + wrapper = s3transfer.crt.CRTCredentialsWrapper(botocore_credentials) + return wrapper.to_crt_credentials_provider() + def _upload_with_crt_transfer_manager(self, fileobj, key=None): if key is None: key = self.s3_key diff --git a/tests/unit/customizations/s3/test_factory.py b/tests/unit/customizations/s3/test_factory.py index f2c261068f35..457b6b74d607 100644 --- a/tests/unit/customizations/s3/test_factory.py +++ b/tests/unit/customizations/s3/test_factory.py @@ -16,6 +16,7 @@ from awscrt.s3 import S3RequestTlsMode from botocore.session import Session from botocore.config import Config +from botocore.credentials import Credentials from botocore.httpsession import DEFAULT_CA_BUNDLE from s3transfer.manager import TransferManager import s3transfer.crt @@ -324,20 +325,31 @@ def test_can_disable_tls_using_endpoint_scheme_for_crt_manager( self.assert_tls_disabled_for_crt_client(mock_crt_client) @mock.patch('s3transfer.crt.S3Client') - def test_uses_botocore_credential_provider_for_crt_manager( + def test_uses_botocore_credentials_for_crt_manager( self, mock_crt_client): + credentials = Credentials('access_key', 'secret_key', 'token') + self.session.get_credentials.return_value = credentials self.runtime_config = self.get_runtime_config( preferred_transfer_client='crt') transfer_manager = self.factory.create_transfer_manager( self.params, self.runtime_config) self.assert_is_crt_manager(transfer_manager) - self.session.get_component.assert_called_with('credential_provider') - self.assertIsNotNone( - mock_crt_client.call_args[1]['credential_provider'] - ) + self.session.get_credentials.assert_called_with() + crt_credential_provider = mock_crt_client.call_args[1][ + 'credential_provider' + ] + self.assertIsNotNone(crt_credential_provider) + + # Ensure the credentials returned by the CRT credential provider + # match the session's credentials + future = crt_credential_provider.get_credentials() + crt_credentials = future.result() + assert crt_credentials.access_key_id == 'access_key' + assert crt_credentials.secret_access_key == 'secret_key' + assert crt_credentials.session_token == 'token' @mock.patch('s3transfer.crt.S3Client') - def test_disable_botocore_credential_provider_for_crt_manager( + def test_disable_botocore_credentials_for_crt_manager( self, mock_crt_client): self.runtime_config = self.get_runtime_config( preferred_transfer_client='crt') @@ -345,7 +357,7 @@ def test_disable_botocore_credential_provider_for_crt_manager( transfer_manager = self.factory.create_transfer_manager( self.params, self.runtime_config) self.assert_is_crt_manager(transfer_manager) - self.session.get_component.assert_not_called() + self.session.get_credentials.assert_not_called() self.assertIsNone( mock_crt_client.call_args[1]['credential_provider'] ) diff --git a/tests/unit/s3transfer/test_crt.py b/tests/unit/s3transfer/test_crt.py index 5d3dd40d803e..e9d8cda0d2b1 100644 --- a/tests/unit/s3transfer/test_crt.py +++ b/tests/unit/s3transfer/test_crt.py @@ -13,7 +13,8 @@ import io import pytest -from botocore.credentials import CredentialResolver, ReadOnlyCredentials +from botocore.credentials import Credentials, ReadOnlyCredentials +from botocore.exceptions import NoCredentialsError from botocore.session import Session from s3transfer.exceptions import TransferNotDoneError @@ -21,6 +22,7 @@ from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest if HAS_CRT: + import awscrt.auth import awscrt.s3 import s3transfer.crt @@ -163,45 +165,70 @@ def test_delete_request(self): self.assertIsNone(crt_request.headers.get("Authorization")) -@requires_crt -class TestCRTCredentialProviderAdapter(unittest.TestCase): - def setUp(self): - self.botocore_credential_provider = mock.Mock(CredentialResolver) - self.access_key = "access_key" - self.secret_key = "secret_key" - self.token = "token" - self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials( - self.access_key, self.secret_key, self.token +@requires_crt_pytest +class TestCRTCredentialsWrapper: + @pytest.fixture + def botocore_credentials(self): + return Credentials( + access_key='access_key', secret_key='secret_key', token='token' ) - def _call_adapter_and_check(self, credentails_provider_adapter): - credentials = credentails_provider_adapter() - self.assertEqual(credentials.access_key_id, self.access_key) - self.assertEqual(credentials.secret_access_key, self.secret_key) - self.assertEqual(credentials.session_token, self.token) - - def test_fetch_crt_credentials_successfully(self): - credentails_provider_adapter = ( - s3transfer.crt.CRTCredentialProviderAdapter( - self.botocore_credential_provider - ) + def assert_crt_credentials( + self, + crt_credentials, + expected_access_key='access_key', + expected_secret_key='secret_key', + expected_token='token', + ): + assert crt_credentials.access_key_id == expected_access_key + assert crt_credentials.secret_access_key == expected_secret_key + assert crt_credentials.session_token == expected_token + + def test_fetch_crt_credentials_successfully(self, botocore_credentials): + wrapper = s3transfer.crt.CRTCredentialsWrapper(botocore_credentials) + crt_credentials = wrapper() + self.assert_crt_credentials(crt_credentials) + + def test_wrapper_does_not_cache_frozen_credentials(self): + mock_credentials = mock.Mock(Credentials) + mock_credentials.get_frozen_credentials.side_effect = [ + ReadOnlyCredentials('access_key_1', 'secret_key_1', 'token_1'), + ReadOnlyCredentials('access_key_2', 'secret_key_2', 'token_2'), + ] + wrapper = s3transfer.crt.CRTCredentialsWrapper(mock_credentials) + + crt_credentials_1 = wrapper() + self.assert_crt_credentials( + crt_credentials_1, + expected_access_key='access_key_1', + expected_secret_key='secret_key_1', + expected_token='token_1' ) - self._call_adapter_and_check(credentails_provider_adapter) - def test_load_credentials_once(self): - credentails_provider_adapter = ( - s3transfer.crt.CRTCredentialProviderAdapter( - self.botocore_credential_provider - ) + crt_credentials_2 = wrapper() + self.assert_crt_credentials( + crt_credentials_2, + expected_access_key='access_key_2', + expected_secret_key='secret_key_2', + expected_token='token_2' ) - called_times = 5 - for i in range(called_times): - self._call_adapter_and_check(credentails_provider_adapter) - # Assert that the load_credentails of botocore credential provider - # will only be called once - self.assertEqual( - self.botocore_credential_provider.load_credentials.call_count, 1 + + assert mock_credentials.get_frozen_credentials.call_count == 2 + + def test_raises_error_when_resolved_credentials_is_none(self): + wrapper = s3transfer.crt.CRTCredentialsWrapper(None) + with pytest.raises(NoCredentialsError): + wrapper() + + def test_to_crt_credentials_provider(self, botocore_credentials): + wrapper = s3transfer.crt.CRTCredentialsWrapper(botocore_credentials) + crt_credentials_provider = wrapper.to_crt_credentials_provider() + assert isinstance( + crt_credentials_provider, awscrt.auth.AwsCredentialsProvider ) + get_credentials_future = crt_credentials_provider.get_credentials() + crt_credentials = get_credentials_future.result() + self.assert_crt_credentials(crt_credentials) @requires_crt