Skip to content

Commit

Permalink
POC for using botocore credentials directly
Browse files Browse the repository at this point in the history
  • Loading branch information
kyleknap committed Nov 16, 2023
1 parent fd7c9ba commit 12cd57d
Show file tree
Hide file tree
Showing 5 changed files with 120 additions and 84 deletions.
12 changes: 9 additions & 3 deletions awscli/customizations/s3/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from s3transfer.manager import TransferManager
from s3transfer.crt import (
acquire_crt_s3_process_lock, create_s3_crt_client,
BotocoreCRTRequestSerializer, CRTTransferManager
BotocoreCRTRequestSerializer, CRTTransferManager,
CRTCredentialsWrapper
)

from awscli.compat import urlparse
Expand Down Expand Up @@ -131,9 +132,9 @@ def _create_crt_client(self, params, runtime_config):
if multipart_chunksize:
create_crt_client_kwargs['part_size'] = multipart_chunksize
if params.get('sign_request', True):
crt_credentials_provider = self._get_crt_credentials_provider()
create_crt_client_kwargs[
'botocore_credential_provider'] = self._session.get_component(
'credential_provider')
'crt_credentials_provider'] = crt_credentials_provider

return create_s3_crt_client(**create_crt_client_kwargs)

Expand Down Expand Up @@ -163,6 +164,11 @@ def _create_classic_transfer_manager(self, params, runtime_config,
)
return TransferManager(client, transfer_config)

def _get_crt_credentials_provider(self):
botocore_credentials = self._session.get_credentials()
wrapper = CRTCredentialsWrapper(botocore_credentials)
return wrapper.to_crt_credentials_provider()

def _resolve_region(self, params):
region = params.get('region')
if region is None:
Expand Down
63 changes: 25 additions & 38 deletions awscli/s3transfer/crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,33 +71,9 @@ def acquire_crt_s3_process_lock(name):
return CRT_S3_PROCESS_LOCK


class CRTCredentialProviderAdapter:
def __init__(self, botocore_credential_provider):
self._botocore_credential_provider = botocore_credential_provider
self._loaded_credentials = None
self._lock = threading.Lock()

def __call__(self):
credentials = self._get_credentials().get_frozen_credentials()
return AwsCredentials(
credentials.access_key, credentials.secret_key, credentials.token
)

def _get_credentials(self):
with self._lock:
if self._loaded_credentials is None:
loaded_creds = (
self._botocore_credential_provider.load_credentials()
)
if loaded_creds is None:
raise NoCredentialsError()
self._loaded_credentials = loaded_creds
return self._loaded_credentials


def create_s3_crt_client(
region,
botocore_credential_provider=None,
crt_credentials_provider=None,
num_threads=None,
target_throughput=None,
part_size=8 * MB,
Expand All @@ -108,10 +84,10 @@ def create_s3_crt_client(
:type region: str
:param region: The region used for signing
:type botocore_credential_provider:
Optional[botocore.credentials.CredentialResolver]
:param botocore_credential_provider: Provide credentials for CRT
to sign the request if not set, the request will not be signed
:type crt_credentials_provider:
Optional[awscrt.auth.AwsCredentialsProvider]
:param crt_credentials_provider: CRT AWS credentials provider
to use to sign requests. If not set, requests will not be signed.
:type num_threads: Optional[int]
:param num_threads: Number of worker threads generated. Default
Expand Down Expand Up @@ -151,7 +127,6 @@ def create_s3_crt_client(
event_loop_group = EventLoopGroup(num_threads)
host_resolver = DefaultHostResolver(event_loop_group)
bootstrap = ClientBootstrap(event_loop_group, host_resolver)
provider = None
tls_connection_options = None

tls_mode = (
Expand All @@ -167,20 +142,13 @@ def create_s3_crt_client(
tls_ctx_options.verify_peer = False
client_tls_option = ClientTlsContext(tls_ctx_options)
tls_connection_options = client_tls_option.new_connection_options()
if botocore_credential_provider:
credentails_provider_adapter = CRTCredentialProviderAdapter(
botocore_credential_provider
)
provider = AwsCredentialsProvider.new_delegate(
credentails_provider_adapter
)
target_gbps = _get_crt_throughput_target_gbps(
provided_throughput_target_bytes=target_throughput
)
return S3Client(
bootstrap=bootstrap,
region=region,
credential_provider=provider,
credential_provider=crt_credentials_provider,
part_size=part_size,
tls_mode=tls_mode,
tls_connection_options=tls_connection_options,
Expand Down Expand Up @@ -575,6 +543,25 @@ def stream(self, amt=1024, decode_content=None):
yield chunk


class CRTCredentialsWrapper:
def __init__(self, resolved_botocore_credentials):
self._resolved_credentials = resolved_botocore_credentials

def __call__(self):
credentials = self._get_credentials().get_frozen_credentials()
return AwsCredentials(
credentials.access_key, credentials.secret_key, credentials.token
)

def to_crt_credentials_provider(self):
return AwsCredentialsProvider.new_delegate(self)

def _get_credentials(self):
if self._resolved_credentials is None:
raise NoCredentialsError()
return self._resolved_credentials


class CRTTransferCoordinator:
"""A helper class for managing CRTTransferFuture"""

Expand Down
8 changes: 6 additions & 2 deletions tests/integration/s3transfer/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,16 +60,20 @@ def _create_s3_transfer(self):
self.request_serializer = s3transfer.crt.BotocoreCRTRequestSerializer(
self.session, client_kwargs={'region_name': self.region}
)
credetial_resolver = self.session.get_component('credential_provider')
self.s3_crt_client = s3transfer.crt.create_s3_crt_client(
self.region, credetial_resolver
self.region, self._get_crt_credentials_provider()
)
self.record_subscriber = RecordingSubscriber()
self.osutil = OSUtils()
return s3transfer.crt.CRTTransferManager(
self.s3_crt_client, self.request_serializer
)

def _get_crt_credentials_provider(self):
botocore_credentials = self.session.get_credentials()
wrapper = s3transfer.crt.CRTCredentialsWrapper(botocore_credentials)
return wrapper.to_crt_credentials_provider()

def _upload_with_crt_transfer_manager(self, fileobj, key=None):
if key is None:
key = self.s3_key
Expand Down
26 changes: 19 additions & 7 deletions tests/unit/customizations/s3/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from awscrt.s3 import S3RequestTlsMode
from botocore.session import Session
from botocore.config import Config
from botocore.credentials import Credentials
from botocore.httpsession import DEFAULT_CA_BUNDLE
from s3transfer.manager import TransferManager
import s3transfer.crt
Expand Down Expand Up @@ -324,28 +325,39 @@ def test_can_disable_tls_using_endpoint_scheme_for_crt_manager(
self.assert_tls_disabled_for_crt_client(mock_crt_client)

@mock.patch('s3transfer.crt.S3Client')
def test_uses_botocore_credential_provider_for_crt_manager(
def test_uses_botocore_credentials_for_crt_manager(
self, mock_crt_client):
credentials = Credentials('access_key', 'secret_key', 'token')
self.session.get_credentials.return_value = credentials
self.runtime_config = self.get_runtime_config(
preferred_transfer_client='crt')
transfer_manager = self.factory.create_transfer_manager(
self.params, self.runtime_config)
self.assert_is_crt_manager(transfer_manager)
self.session.get_component.assert_called_with('credential_provider')
self.assertIsNotNone(
mock_crt_client.call_args[1]['credential_provider']
)
self.session.get_credentials.assert_called_with()
crt_credential_provider = mock_crt_client.call_args[1][
'credential_provider'
]
self.assertIsNotNone(crt_credential_provider)

# Ensure the credentials returned by the CRT credential provider
# match the session's credentials
future = crt_credential_provider.get_credentials()
crt_credentials = future.result()
assert crt_credentials.access_key_id == 'access_key'
assert crt_credentials.secret_access_key == 'secret_key'
assert crt_credentials.session_token == 'token'

@mock.patch('s3transfer.crt.S3Client')
def test_disable_botocore_credential_provider_for_crt_manager(
def test_disable_botocore_credentials_for_crt_manager(
self, mock_crt_client):
self.runtime_config = self.get_runtime_config(
preferred_transfer_client='crt')
self.params['sign_request'] = False
transfer_manager = self.factory.create_transfer_manager(
self.params, self.runtime_config)
self.assert_is_crt_manager(transfer_manager)
self.session.get_component.assert_not_called()
self.session.get_credentials.assert_not_called()
self.assertIsNone(
mock_crt_client.call_args[1]['credential_provider']
)
Expand Down
95 changes: 61 additions & 34 deletions tests/unit/s3transfer/test_crt.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,16 @@
import io

import pytest
from botocore.credentials import CredentialResolver, ReadOnlyCredentials
from botocore.credentials import Credentials, ReadOnlyCredentials
from botocore.exceptions import NoCredentialsError
from botocore.session import Session

from s3transfer.exceptions import TransferNotDoneError
from s3transfer.utils import CallArgs
from tests import HAS_CRT, FileCreator, mock, requires_crt, unittest

if HAS_CRT:
import awscrt.auth
import awscrt.s3

import s3transfer.crt
Expand Down Expand Up @@ -163,45 +165,70 @@ def test_delete_request(self):
self.assertIsNone(crt_request.headers.get("Authorization"))


@requires_crt
class TestCRTCredentialProviderAdapter(unittest.TestCase):
def setUp(self):
self.botocore_credential_provider = mock.Mock(CredentialResolver)
self.access_key = "access_key"
self.secret_key = "secret_key"
self.token = "token"
self.botocore_credential_provider.load_credentials.return_value.get_frozen_credentials.return_value = ReadOnlyCredentials(
self.access_key, self.secret_key, self.token
@requires_crt_pytest
class TestCRTCredentialsWrapper:
@pytest.fixture
def botocore_credentials(self):
return Credentials(
access_key='access_key', secret_key='secret_key', token='token'
)

def _call_adapter_and_check(self, credentails_provider_adapter):
credentials = credentails_provider_adapter()
self.assertEqual(credentials.access_key_id, self.access_key)
self.assertEqual(credentials.secret_access_key, self.secret_key)
self.assertEqual(credentials.session_token, self.token)

def test_fetch_crt_credentials_successfully(self):
credentails_provider_adapter = (
s3transfer.crt.CRTCredentialProviderAdapter(
self.botocore_credential_provider
)
def assert_crt_credentials(
self,
crt_credentials,
expected_access_key='access_key',
expected_secret_key='secret_key',
expected_token='token',
):
assert crt_credentials.access_key_id == expected_access_key
assert crt_credentials.secret_access_key == expected_secret_key
assert crt_credentials.session_token == expected_token

def test_fetch_crt_credentials_successfully(self, botocore_credentials):
wrapper = s3transfer.crt.CRTCredentialsWrapper(botocore_credentials)
crt_credentials = wrapper()
self.assert_crt_credentials(crt_credentials)

def test_wrapper_does_not_cache_frozen_credentials(self):
mock_credentials = mock.Mock(Credentials)
mock_credentials.get_frozen_credentials.side_effect = [
ReadOnlyCredentials('access_key_1', 'secret_key_1', 'token_1'),
ReadOnlyCredentials('access_key_2', 'secret_key_2', 'token_2'),
]
wrapper = s3transfer.crt.CRTCredentialsWrapper(mock_credentials)

crt_credentials_1 = wrapper()
self.assert_crt_credentials(
crt_credentials_1,
expected_access_key='access_key_1',
expected_secret_key='secret_key_1',
expected_token='token_1'
)
self._call_adapter_and_check(credentails_provider_adapter)

def test_load_credentials_once(self):
credentails_provider_adapter = (
s3transfer.crt.CRTCredentialProviderAdapter(
self.botocore_credential_provider
)
crt_credentials_2 = wrapper()
self.assert_crt_credentials(
crt_credentials_2,
expected_access_key='access_key_2',
expected_secret_key='secret_key_2',
expected_token='token_2'
)
called_times = 5
for i in range(called_times):
self._call_adapter_and_check(credentails_provider_adapter)
# Assert that the load_credentails of botocore credential provider
# will only be called once
self.assertEqual(
self.botocore_credential_provider.load_credentials.call_count, 1

assert mock_credentials.get_frozen_credentials.call_count == 2

def test_raises_error_when_resolved_credentials_is_none(self):
wrapper = s3transfer.crt.CRTCredentialsWrapper(None)
with pytest.raises(NoCredentialsError):
wrapper()

def test_to_crt_credentials_provider(self, botocore_credentials):
wrapper = s3transfer.crt.CRTCredentialsWrapper(botocore_credentials)
crt_credentials_provider = wrapper.to_crt_credentials_provider()
assert isinstance(
crt_credentials_provider, awscrt.auth.AwsCredentialsProvider
)
get_credentials_future = crt_credentials_provider.get_credentials()
crt_credentials = get_credentials_future.result()
self.assert_crt_credentials(crt_credentials)


@requires_crt
Expand Down

0 comments on commit 12cd57d

Please sign in to comment.