From a86f7cf04b2e4dc89a868f5c198b9c057aada7de Mon Sep 17 00:00:00 2001 From: kyleknap Date: Tue, 21 Nov 2023 12:45:32 -0800 Subject: [PATCH] Update interface for providing credentials to CRT The two main changes are: * Update the CRT S3 client factory to accept a CRT credential provider instead of a botocore credential provider. Under the hood, the S3 client only accepts a CRT credential provider. So, this provides more flexibility in being able to provide other CRT credenital providers directly instead of being forced to the botocore credential provider interface * Update the botocore to CRT credentials adapter interface to only accept botocore credential objects instead of credential providers. In general, the credentials object is more accessible than the provider; it can be retrieved at the session level and is what is passed into clients. Also, this change avoids a limitation where the load_credentials() method on the credential provider cannot be called more than twice for some configurations (e.g. assume role from profile), which can be an issue if you create both a botocore client and CRT S3 client. --- s3transfer/crt.py | 63 +++++++++------------ tests/integration/test_crt.py | 10 +++- tests/unit/test_crt.py | 101 ++++++++++++++++++++++------------ 3 files changed, 100 insertions(+), 74 deletions(-) diff --git a/s3transfer/crt.py b/s3transfer/crt.py index 465d58d3..38923c4e 100644 --- a/s3transfer/crt.py +++ b/s3transfer/crt.py @@ -71,33 +71,9 @@ def acquire_crt_s3_process_lock(name): return CRT_S3_PROCESS_LOCK -class CRTCredentialProviderAdapter: - def __init__(self, botocore_credential_provider): - self._botocore_credential_provider = botocore_credential_provider - self._loaded_credentials = None - self._lock = threading.Lock() - - def __call__(self): - credentials = self._get_credentials().get_frozen_credentials() - return AwsCredentials( - credentials.access_key, credentials.secret_key, credentials.token - ) - - def _get_credentials(self): - with self._lock: - if self._loaded_credentials is None: - loaded_creds = ( - self._botocore_credential_provider.load_credentials() - ) - if loaded_creds is None: - raise NoCredentialsError() - self._loaded_credentials = loaded_creds - return self._loaded_credentials - - def create_s3_crt_client( region, - botocore_credential_provider=None, + crt_credentials_provider=None, num_threads=None, target_throughput=None, part_size=8 * MB, @@ -108,10 +84,10 @@ def create_s3_crt_client( :type region: str :param region: The region used for signing - :type botocore_credential_provider: - Optional[botocore.credentials.CredentialResolver] - :param botocore_credential_provider: Provide credentials for CRT - to sign the request if not set, the request will not be signed + :type crt_credentials_provider: + Optional[awscrt.auth.AwsCredentialsProvider] + :param crt_credentials_provider: CRT AWS credentials provider + to use to sign requests. If not set, requests will not be signed. :type num_threads: Optional[int] :param num_threads: Number of worker threads generated. Default @@ -151,7 +127,6 @@ def create_s3_crt_client( event_loop_group = EventLoopGroup(num_threads) host_resolver = DefaultHostResolver(event_loop_group) bootstrap = ClientBootstrap(event_loop_group, host_resolver) - provider = None tls_connection_options = None tls_mode = ( @@ -167,20 +142,13 @@ def create_s3_crt_client( tls_ctx_options.verify_peer = False client_tls_option = ClientTlsContext(tls_ctx_options) tls_connection_options = client_tls_option.new_connection_options() - if botocore_credential_provider: - credentails_provider_adapter = CRTCredentialProviderAdapter( - botocore_credential_provider - ) - provider = AwsCredentialsProvider.new_delegate( - credentails_provider_adapter - ) target_gbps = _get_crt_throughput_target_gbps( provided_throughput_target_bytes=target_throughput ) return S3Client( bootstrap=bootstrap, region=region, - credential_provider=provider, + credential_provider=crt_credentials_provider, part_size=part_size, tls_mode=tls_mode, tls_connection_options=tls_connection_options, @@ -575,6 +543,25 @@ def stream(self, amt=1024, decode_content=None): yield chunk +class BotocoreCRTCredentialsWrapper: + def __init__(self, resolved_botocore_credentials): + self._resolved_credentials = resolved_botocore_credentials + + def __call__(self): + credentials = self._get_credentials().get_frozen_credentials() + return AwsCredentials( + credentials.access_key, credentials.secret_key, credentials.token + ) + + def to_crt_credentials_provider(self): + return AwsCredentialsProvider.new_delegate(self) + + def _get_credentials(self): + if self._resolved_credentials is None: + raise NoCredentialsError() + return self._resolved_credentials + + class CRTTransferCoordinator: """A helper class for managing CRTTransferFuture""" diff --git a/tests/integration/test_crt.py b/tests/integration/test_crt.py index 7881fa63..7f16d85e 100644 --- a/tests/integration/test_crt.py +++ b/tests/integration/test_crt.py @@ -60,9 +60,8 @@ def _create_s3_transfer(self): self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer( self.session, client_kwargs={'region_name': self.region} ) - credetial_resolver = self.session.get_component('credential_provider') self.s3_crt_client = s3transfer.crt.create_s3_crt_client( - self.region, credetial_resolver + self.region, self._get_crt_credentials_provider() ) self.record_subscriber = RecordingSubscriber() self.osutil = OSUtils() @@ -70,6 +69,13 @@ def _create_s3_transfer(self): self.s3_crt_client, self.request_serializer ) + def _get_crt_credentials_provider(self): + botocore_credentials = self.session.get_credentials() + wrapper = s3transfer.crt.BotocoreCRTCredentialsWrapper( + botocore_credentials + ) + return wrapper.to_crt_credentials_provider() + def _upload_with_crt_transfer_manager(self, fileobj, key=None): if key is None: key = self.s3_key diff --git a/tests/unit/test_crt.py b/tests/unit/test_crt.py index 5d3dd40d..6442301a 100644 --- a/tests/unit/test_crt.py +++ b/tests/unit/test_crt.py @@ -13,7 +13,8 @@ import io import pytest -from botocore.credentials import CredentialResolver, ReadOnlyCredentials +from botocore.credentials import Credentials, ReadOnlyCredentials +from botocore.exceptions import NoCredentialsError from botocore.session import Session from s3transfer.exceptions import TransferNotDoneError @@ -21,6 +22,7 @@ from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest if HAS_CRT: + import awscrt.auth import awscrt.s3 import s3transfer.crt @@ -163,45 +165,76 @@ def test_delete_request(self): self.assertIsNone(crt_request.headers.get("Authorization")) -@requires_crt -class TestCRTCredentialProviderAdapter(unittest.TestCase): - def setUp(self): - self.botocore_credential_provider = mock.Mock(CredentialResolver) - self.access_key = "access_key" - self.secret_key = "secret_key" - self.token = "token" - self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials( - self.access_key, self.secret_key, self.token +@requires_crt_pytest +class TestBotocoreCRTCredentialsWrapper: + @pytest.fixture + def botocore_credentials(self): + return Credentials( + access_key='access_key', secret_key='secret_key', token='token' ) - def _call_adapter_and_check(self, credentails_provider_adapter): - credentials = credentails_provider_adapter() - self.assertEqual(credentials.access_key_id, self.access_key) - self.assertEqual(credentials.secret_access_key, self.secret_key) - self.assertEqual(credentials.session_token, self.token) - - def test_fetch_crt_credentials_successfully(self): - credentails_provider_adapter = ( - s3transfer.crt.CRTCredentialProviderAdapter( - self.botocore_credential_provider - ) + def assert_crt_credentials( + self, + crt_credentials, + expected_access_key='access_key', + expected_secret_key='secret_key', + expected_token='token', + ): + assert crt_credentials.access_key_id == expected_access_key + assert crt_credentials.secret_access_key == expected_secret_key + assert crt_credentials.session_token == expected_token + + def test_fetch_crt_credentials_successfully(self, botocore_credentials): + wrapper = s3transfer.crt.BotocoreCRTCredentialsWrapper( + botocore_credentials + ) + crt_credentials = wrapper() + self.assert_crt_credentials(crt_credentials) + + def test_wrapper_does_not_cache_frozen_credentials(self): + mock_credentials = mock.Mock(Credentials) + mock_credentials.get_frozen_credentials.side_effect = [ + ReadOnlyCredentials('access_key_1', 'secret_key_1', 'token_1'), + ReadOnlyCredentials('access_key_2', 'secret_key_2', 'token_2'), + ] + wrapper = s3transfer.crt.BotocoreCRTCredentialsWrapper( + mock_credentials ) - self._call_adapter_and_check(credentails_provider_adapter) - def test_load_credentials_once(self): - credentails_provider_adapter = ( - s3transfer.crt.CRTCredentialProviderAdapter( - self.botocore_credential_provider - ) + crt_credentials_1 = wrapper() + self.assert_crt_credentials( + crt_credentials_1, + expected_access_key='access_key_1', + expected_secret_key='secret_key_1', + expected_token='token_1', + ) + + crt_credentials_2 = wrapper() + self.assert_crt_credentials( + crt_credentials_2, + expected_access_key='access_key_2', + expected_secret_key='secret_key_2', + expected_token='token_2', + ) + + assert mock_credentials.get_frozen_credentials.call_count == 2 + + def test_raises_error_when_resolved_credentials_is_none(self): + wrapper = s3transfer.crt.BotocoreCRTCredentialsWrapper(None) + with pytest.raises(NoCredentialsError): + wrapper() + + def test_to_crt_credentials_provider(self, botocore_credentials): + wrapper = s3transfer.crt.BotocoreCRTCredentialsWrapper( + botocore_credentials ) - called_times = 5 - for i in range(called_times): - self._call_adapter_and_check(credentails_provider_adapter) - # Assert that the load_credentails of botocore credential provider - # will only be called once - self.assertEqual( - self.botocore_credential_provider.load_credentials.call_count, 1 + crt_credentials_provider = wrapper.to_crt_credentials_provider() + assert isinstance( + crt_credentials_provider, awscrt.auth.AwsCredentialsProvider ) + get_credentials_future = crt_credentials_provider.get_credentials() + crt_credentials = get_credentials_future.result() + self.assert_crt_credentials(crt_credentials) @requires_crt