diff --git a/CHANGES.rst b/CHANGES.rst index 288b1d08..a9e202ca 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -1,5 +1,11 @@ Changes ------- +2.3.0 (2022-05-05) +^^^^^^^^^^^^^^^^^^ +* fix encoding issue by swapping to AioAWSResponse and AioAWSRequest to behave more + like botocore +* fix exceptions mappings + 2.2.0 (2022-03-16) ^^^^^^^^^^^^^^^^^^ * remove deprecated APIs diff --git a/aiobotocore/__init__.py b/aiobotocore/__init__.py index 04188a16..82190396 100644 --- a/aiobotocore/__init__.py +++ b/aiobotocore/__init__.py @@ -1 +1 @@ -__version__ = '2.2.0' +__version__ = '2.3.0' diff --git a/aiobotocore/_endpoint_helpers.py b/aiobotocore/_endpoint_helpers.py index 16d71dfb..78974f8e 100644 --- a/aiobotocore/_endpoint_helpers.py +++ b/aiobotocore/_endpoint_helpers.py @@ -1,5 +1,4 @@ import aiohttp.http_exceptions -from aiohttp.client_reqrep import ClientResponse import asyncio import botocore.retryhandler import wrapt @@ -33,67 +32,3 @@ class _IOBaseWrapper(wrapt.ObjectProxy): def close(self): # this stream should not be closed by aiohttp, like 1.x pass - - -# This is similar to botocore.response.StreamingBody -class ClientResponseContentProxy(wrapt.ObjectProxy): - """Proxy object for content stream of http response. This is here in case - you want to pass around the "Body" of the response without closing the - response itself.""" - - def __init__(self, response): - super().__init__(response.__wrapped__.content) - self._self_response = response - - # Note: we don't have a __del__ method as the ClientResponse has a __del__ - # which will warn the user if they didn't close/release the response - # explicitly. A release here would mean reading all the unread data - # (which could be very large), and a close would mean being unable to re- - # use the connection, so the user MUST chose. Default is to warn + close - async def __aenter__(self): - await self._self_response.__aenter__() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - return await self._self_response.__aexit__(exc_type, exc_val, exc_tb) - - @property - def url(self): - return self._self_response.url - - def close(self): - self._self_response.close() - - -class ClientResponseProxy(wrapt.ObjectProxy): - """Proxy object for http response useful for porting from - botocore underlying http library.""" - - def __init__(self, *args, **kwargs): - super().__init__(ClientResponse(*args, **kwargs)) - - # this matches ClientResponse._body - self._self_body = None - - @property - def status_code(self): - return self.status - - @status_code.setter - def status_code(self, value): - # botocore tries to set this, see: - # https://github.com/aio-libs/aiobotocore/issues/190 - # Luckily status is an attribute we can set - self.status = value - - @property - def content(self): - return self._self_body - - @property - def raw(self): - return ClientResponseContentProxy(self) - - async def read(self): - self._self_body = await self.__wrapped__.read() - return self._self_body diff --git a/aiobotocore/_helpers.py b/aiobotocore/_helpers.py new file mode 100644 index 00000000..e6f7e835 --- /dev/null +++ b/aiobotocore/_helpers.py @@ -0,0 +1,16 @@ +import inspect + + +async def resolve_awaitable(obj): + if inspect.isawaitable(obj): + return await obj + + return obj + + +async def async_any(items): + for item in items: + if await resolve_awaitable(item): + return True + + return False diff --git a/aiobotocore/awsrequest.py b/aiobotocore/awsrequest.py new file mode 100644 index 00000000..3315823b --- /dev/null +++ b/aiobotocore/awsrequest.py @@ -0,0 +1,30 @@ +from botocore.awsrequest import AWSResponse +import botocore.utils + + +class AioAWSResponse(AWSResponse): + # Unlike AWSResponse, these return awaitables + + async def _content_prop(self): + """Content of the response as bytes.""" + + if self._content is None: + # NOTE: this will cache the data in self.raw + self._content = await self.raw.read() or bytes() + + return self._content + + @property + def content(self): + return self._content_prop() + + async def _text_prop(self): + encoding = botocore.utils.get_encoding_from_headers(self.headers) + if encoding: + return (await self.content).decode(encoding) + else: + return (await self.content).decode('utf-8') + + @property + def text(self): + return self._text_prop() diff --git a/aiobotocore/client.py b/aiobotocore/client.py index 43fd0431..d8384bbb 100644 --- a/aiobotocore/client.py +++ b/aiobotocore/client.py @@ -15,6 +15,8 @@ from .discovery import AioEndpointDiscoveryManager, AioEndpointDiscoveryHandler from .retries import adaptive from . import waiter +from .retries import standard + history_recorder = get_global_history_recorder() @@ -124,11 +126,46 @@ def _register_retries(self, client): elif retry_mode == 'legacy': self._register_legacy_retries(client) + def _register_v2_standard_retries(self, client): + max_attempts = client.meta.config.retries.get('total_max_attempts') + kwargs = {'client': client} + if max_attempts is not None: + kwargs['max_attempts'] = max_attempts + standard.register_retry_handler(**kwargs) + def _register_v2_adaptive_retries(self, client): # See comment in `_register_retries`. # Note that this `adaptive` module is an aiobotocore reimplementation. adaptive.register_retry_handler(client) + def _register_legacy_retries(self, client): + endpoint_prefix = client.meta.service_model.endpoint_prefix + service_id = client.meta.service_model.service_id + service_event_name = service_id.hyphenize() + + # First, we load the entire retry config for all services, + # then pull out just the information we need. + original_config = self._loader.load_data('_retry') + if not original_config: + return + + retries = self._transform_legacy_retries(client.meta.config.retries) + retry_config = self._retry_config_translator.build_retry_config( + endpoint_prefix, original_config.get('retry', {}), + original_config.get('definitions', {}), + retries + ) + + logger.debug("Registering retry handlers for service: %s", + client.meta.service_model.service_name) + handler = self._retry_handler_factory.create_retry_handler( + retry_config, endpoint_prefix) + unique_id = 'retry-config-%s' % service_event_name + client.meta.events.register( + 'needs-retry.%s' % service_event_name, handler, + unique_id=unique_id + ) + def _register_s3_events(self, client, endpoint_bridge, endpoint_url, client_config, scoped_config): if client.meta.service_model.service_name != 's3': diff --git a/aiobotocore/endpoint.py b/aiobotocore/endpoint.py index b5cfc867..8887645b 100644 --- a/aiobotocore/endpoint.py +++ b/aiobotocore/endpoint.py @@ -1,17 +1,14 @@ -import aiohttp import asyncio -import aiohttp.http_exceptions from botocore.endpoint import EndpointCreator, Endpoint, DEFAULT_TIMEOUT, \ MAX_POOL_CONNECTIONS, logger, history_recorder, create_request_object, \ - is_valid_ipv6_endpoint_url, is_valid_endpoint_url, handle_checksum_body -from botocore.exceptions import ConnectionClosedError + is_valid_ipv6_endpoint_url, is_valid_endpoint_url, HTTPClientError from botocore.hooks import first_non_none_response from urllib3.response import HTTPHeaderDict from aiobotocore.httpsession import AIOHTTPSession from aiobotocore.response import StreamingBody -from aiobotocore._endpoint_helpers import ClientResponseProxy # noqa: F401, E501 lgtm [py/unused-import] +from aiobotocore.httpchecksum import handle_checksum_body async def convert_to_response_dict(http_response, operation_model): @@ -37,21 +34,21 @@ async def convert_to_response_dict(http_response, operation_model): # aiohttp's CIMultiDict camel cases the headers :( 'headers': HTTPHeaderDict( {k.decode('utf-8').lower(): v.decode('utf-8') - for k, v in http_response.raw_headers}), + for k, v in http_response.raw.raw_headers}), 'status_code': http_response.status_code, 'context': { 'operation_name': operation_model.name, } } if response_dict['status_code'] >= 300: - response_dict['body'] = await http_response.read() + response_dict['body'] = await http_response.content elif operation_model.has_event_stream_output: response_dict['body'] = http_response.raw elif operation_model.has_streaming_output: length = response_dict['headers'].get('content-length') response_dict['body'] = StreamingBody(http_response.raw, length) else: - response_dict['body'] = await http_response.read() + response_dict['body'] = await http_response.content return response_dict @@ -150,13 +147,8 @@ async def _do_get_response(self, request, operation_model, context): http_response = first_non_none_response(responses) if http_response is None: http_response = await self._send(request) - except aiohttp.ClientConnectionError as e: - e.request = request # botocore expects the request property + except HTTPClientError as e: return None, e - except aiohttp.http_exceptions.BadStatusLine: - better_exception = ConnectionClosedError( - endpoint_url=request.url, request=request) - return None, better_exception except Exception as e: logger.debug("Exception received when sending HTTP request.", exc_info=True) @@ -165,7 +157,7 @@ async def _do_get_response(self, request, operation_model, context): # This returns the http_response and the parsed_data. response_dict = await convert_to_response_dict(http_response, operation_model) - handle_checksum_body( + await handle_checksum_body( http_response, response_dict, context, operation_model, ) diff --git a/aiobotocore/handlers.py b/aiobotocore/handlers.py index 48f7c0ff..0d08d666 100644 --- a/aiobotocore/handlers.py +++ b/aiobotocore/handlers.py @@ -1,5 +1,51 @@ from botocore.handlers import _get_presigned_url_source_and_destination_regions, \ - _get_cross_region_presigned_url + _get_cross_region_presigned_url, ETree, logger, XMLParseError + + +async def check_for_200_error(response, **kwargs): + # From: http://docs.aws.amazon.com/AmazonS3/latest/API/RESTObjectCOPY.html + # There are two opportunities for a copy request to return an error. One + # can occur when Amazon S3 receives the copy request and the other can + # occur while Amazon S3 is copying the files. If the error occurs before + # the copy operation starts, you receive a standard Amazon S3 error. If the + # error occurs during the copy operation, the error response is embedded in + # the 200 OK response. This means that a 200 OK response can contain either + # a success or an error. Make sure to design your application to parse the + # contents of the response and handle it appropriately. + # + # So this handler checks for this case. Even though the server sends a + # 200 response, conceptually this should be handled exactly like a + # 500 response (with respect to raising exceptions, retries, etc.) + # We're connected *before* all the other retry logic handlers, so as long + # as we switch the error code to 500, we'll retry the error as expected. + if response is None: + # A None response can happen if an exception is raised while + # trying to retrieve the response. See Endpoint._get_response(). + return + http_response, parsed = response + if await _looks_like_special_case_error(http_response): + logger.debug("Error found for response with 200 status code, " + "errors: %s, changing status code to " + "500.", parsed) + http_response.status_code = 500 + + +async def _looks_like_special_case_error(http_response): + if http_response.status_code == 200: + try: + parser = ETree.XMLParser( + target=ETree.TreeBuilder(), + encoding='utf-8') + parser.feed(await http_response.content) + root = parser.close() + except XMLParseError: + # In cases of network disruptions, we may end up with a partial + # streamed response from S3. We need to treat these cases as + # 500 Service Errors and try again. + return True + if root.tag == 'Error': + return True + return False async def inject_presigned_url_ec2(params, request_signer, model, **kwargs): @@ -36,3 +82,21 @@ async def inject_presigned_url_rds(params, request_signer, model, **kwargs): url = await _get_cross_region_presigned_url( request_signer, params, model, src, dest) params['body']['PreSignedUrl'] = url + + +async def parse_get_bucket_location(parsed, http_response, **kwargs): + # s3.GetBucketLocation cannot be modeled properly. To + # account for this we just manually parse the XML document. + # The "parsed" passed in only has the ResponseMetadata + # filled out. This handler will fill in the LocationConstraint + # value. + if http_response.raw is None: + return + response_body = await http_response.content + parser = ETree.XMLParser( + target=ETree.TreeBuilder(), + encoding='utf-8') + parser.feed(response_body) + root = parser.close() + region = root.text + parsed['LocationConstraint'] = region diff --git a/aiobotocore/hooks.py b/aiobotocore/hooks.py index 496d5f74..9c993daa 100644 --- a/aiobotocore/hooks.py +++ b/aiobotocore/hooks.py @@ -1,6 +1,30 @@ -import asyncio - from botocore.hooks import HierarchicalEmitter, logger +from botocore.handlers import \ + inject_presigned_url_rds as boto_inject_presigned_url_rds, \ + inject_presigned_url_ec2 as boto_inject_presigned_url_ec2, \ + parse_get_bucket_location as boto_parse_get_bucket_location, \ + check_for_200_error as boto_check_for_200_error +from botocore.signers import \ + add_generate_presigned_url as boto_add_generate_presigned_url, \ + add_generate_presigned_post as boto_add_generate_presigned_post, \ + add_generate_db_auth_token as boto_add_generate_db_auth_token + +from ._helpers import resolve_awaitable +from .signers import add_generate_presigned_url, add_generate_presigned_post, \ + add_generate_db_auth_token +from .handlers import inject_presigned_url_ec2, inject_presigned_url_rds, \ + parse_get_bucket_location, check_for_200_error + + +_HANDLER_MAPPING = { + boto_inject_presigned_url_ec2: inject_presigned_url_ec2, + boto_inject_presigned_url_rds: inject_presigned_url_rds, + boto_add_generate_presigned_url: add_generate_presigned_url, + boto_add_generate_presigned_post: add_generate_presigned_post, + boto_add_generate_db_auth_token: add_generate_db_auth_token, + boto_parse_get_bucket_location: parse_get_bucket_location, + boto_check_for_200_error: check_for_200_error +} class AioHierarchicalEmitter(HierarchicalEmitter): @@ -23,11 +47,7 @@ async def _emit(self, event_name, kwargs, stop_on_response=False): logger.debug('Event %s: calling handler %s', event_name, handler) # Await the handler if its a coroutine. - if asyncio.iscoroutinefunction(handler): - response = await handler(**kwargs) - else: - response = handler(**kwargs) - + response = await resolve_awaitable(handler(**kwargs)) responses.append((handler, response)) if stop_on_response and response is not None: return responses @@ -39,3 +59,11 @@ async def emit_until_response(self, event_name, **kwargs): return responses[-1] else: return None, None + + def _verify_and_register(self, event_name, handler, unique_id, + register_method, unique_id_uses_count): + handler = _HANDLER_MAPPING.get(handler, handler) + + self._verify_is_callable(handler) + self._verify_accept_kwargs(handler) + register_method(event_name, handler, unique_id, unique_id_uses_count) diff --git a/aiobotocore/httpchecksum.py b/aiobotocore/httpchecksum.py new file mode 100644 index 00000000..55e3653b --- /dev/null +++ b/aiobotocore/httpchecksum.py @@ -0,0 +1,59 @@ +from botocore.httpchecksum import logger, _CHECKSUM_CLS, base64, \ + FlexibleChecksumError, _handle_streaming_response + + +async def handle_checksum_body(http_response, response, context, operation_model): + headers = response["headers"] + checksum_context = context.get("checksum", {}) + algorithms = checksum_context.get("response_algorithms") + + if not algorithms: + return + + for algorithm in algorithms: + header_name = "x-amz-checksum-%s" % algorithm + # If the header is not found, check the next algorithm + if header_name not in headers: + continue + + # If a - is in the checksum this is not valid Base64. S3 returns + # checksums that include a -# suffix to indicate a checksum derived + # from the hash of all part checksums. We cannot wrap this response + if "-" in headers[header_name]: + continue + + if operation_model.has_streaming_output: + response["body"] = _handle_streaming_response( + http_response, response, algorithm + ) + else: + response["body"] = await _handle_bytes_response( + http_response, response, algorithm + ) + + # Expose metadata that the checksum check actually occurred + checksum_context = response["context"].get("checksum", {}) + checksum_context["response_algorithm"] = algorithm + response["context"]["checksum"] = checksum_context + return + + logger.info( + f'Skipping checksum validation. Response did not contain one of the ' + f'following algorithms: {algorithms}.' + ) + + +async def _handle_bytes_response(http_response, response, algorithm): + body = await http_response.content + header_name = "x-amz-checksum-%s" % algorithm + checksum_cls = _CHECKSUM_CLS.get(algorithm) + checksum = checksum_cls() + checksum.update(body) + expected = response["headers"][header_name] + if checksum.digest() != base64.b64decode(expected): + error_msg = ( + "Expected checksum %s did not match calculated checksum: %s" + % (expected, checksum.b64digest(),) + ) + raise FlexibleChecksumError(error_msg=error_msg) + return body diff --git a/aiobotocore/httpsession.py b/aiobotocore/httpsession.py index 128c44a8..6ad6d573 100644 --- a/aiobotocore/httpsession.py +++ b/aiobotocore/httpsession.py @@ -6,7 +6,8 @@ import aiohttp # lgtm [py/import-and-import-from] from aiohttp import ClientSSLError, ClientConnectorError, ClientProxyConnectionError, \ - ClientHttpProxyError, ServerTimeoutError, ServerDisconnectedError + ClientHttpProxyError, ServerTimeoutError, ServerDisconnectedError, \ + ClientConnectionError from aiohttp.client import URL from multidict import MultiDict @@ -15,9 +16,9 @@ EndpointConnectionError, ProxyConnectionError, ConnectTimeoutError, \ ConnectionClosedError, HTTPClientError, ReadTimeoutError, logger, get_cert_path, \ ensure_boolean, urlparse, mask_proxy_url +import aiobotocore.awsrequest -from aiobotocore._endpoint_helpers import _text, _IOBaseWrapper, \ - ClientResponseProxy +from aiobotocore._endpoint_helpers import _text, _IOBaseWrapper class AIOHTTPSession: @@ -87,18 +88,23 @@ def __init__( if ca_certs: ssl_context.load_verify_locations(ca_certs, None, None) - self._connector = aiohttp.TCPConnector( + self._create_connector = lambda: aiohttp.TCPConnector( limit=max_pool_connections, verify_ssl=bool(verify), ssl=ssl_context, - **connector_args) + **self._connector_args + ) + self._connector = None async def __aenter__(self): + assert not self._session and not self._connector + + self._connector = self._create_connector() + self._session = aiohttp.ClientSession( connector=self._connector, timeout=self._timeout, skip_auto_headers={'CONTENT-TYPE'}, - response_class=ClientResponseProxy, auto_decompress=False) return self @@ -106,6 +112,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb): if self._session: await self._session.__aexit__(exc_type, exc_val, exc_tb) self._session = None + self._connector = None async def close(self): await self.__aexit__(None, None, None) @@ -168,34 +175,42 @@ async def send(self, request): data = _IOBaseWrapper(data) url = URL(url, encoded=True) - resp = await self._session.request( + response = await self._session.request( request.method, url=url, headers=headers_, data=data, proxy=proxy_url, proxy_headers=proxy_headers ) + http_response = aiobotocore.awsrequest.AioAWSResponse( + str(response.url), + response.status, + response.headers, + response + ) + if not request.stream_output: # Cause the raw stream to be exhausted immediately. We do it # this way instead of using preload_content because # preload_content will never buffer chunked responses - await resp.read() + await http_response.content - return resp + return http_response except ClientSSLError as e: raise SSLError(endpoint_url=request.url, error=e) - except (ClientConnectorError, socket.gaierror) as e: - raise EndpointConnectionError(endpoint_url=request.url, error=e) except (ClientProxyConnectionError, ClientHttpProxyError) as e: raise ProxyConnectionError(proxy_url=mask_proxy_url(proxy_url), error=e) - except ServerTimeoutError as e: - raise ConnectTimeoutError(endpoint_url=request.url, error=e) - except asyncio.TimeoutError as e: - raise ReadTimeoutError(endpoint_url=request.url, error=e) - except (ServerDisconnectedError, aiohttp.ClientPayloadError) as e: + except (ServerDisconnectedError, aiohttp.ClientPayloadError, + aiohttp.http_exceptions.BadStatusLine) as e: raise ConnectionClosedError( error=e, request=request, endpoint_url=request.url ) + except ServerTimeoutError as e: + raise ConnectTimeoutError(endpoint_url=request.url, error=e) + except (ClientConnectorError, ClientConnectionError, socket.gaierror) as e: + raise EndpointConnectionError(endpoint_url=request.url, error=e) + except asyncio.TimeoutError as e: + raise ReadTimeoutError(endpoint_url=request.url, error=e) except Exception as e: message = 'Exception received when sending urllib3 HTTP request' logger.debug(message, exc_info=True) diff --git a/aiobotocore/response.py b/aiobotocore/response.py index 6f080b6c..fa193d5d 100644 --- a/aiobotocore/response.py +++ b/aiobotocore/response.py @@ -50,7 +50,7 @@ async def read(self, amt=None): """ # botocore to aiohttp mapping try: - chunk = await self.__wrapped__.read(amt if amt is not None else -1) + chunk = await self.__wrapped__.content.read(amt if amt is not None else -1) except asyncio.TimeoutError as e: raise AioReadTimeoutError(endpoint_url=self.__wrapped__.url, error=e) @@ -134,12 +134,12 @@ async def get_response(operation_model, http_response): # If it looks like an error, in the streaming response case we # need to actually grab the contents. if response_dict['status_code'] >= 300: - response_dict['body'] = http_response.content + response_dict['body'] = await http_response.content elif operation_model.has_streaming_output: response_dict['body'] = StreamingBody( http_response.raw, response_dict['headers'].get('content-length')) else: - response_dict['body'] = http_response.content + response_dict['body'] = await http_response.content parser = parsers.create_parser(protocol) if asyncio.iscoroutinefunction(parser.parse): diff --git a/aiobotocore/retries/special.py b/aiobotocore/retries/special.py new file mode 100644 index 00000000..7c1b9508 --- /dev/null +++ b/aiobotocore/retries/special.py @@ -0,0 +1,18 @@ +from botocore.retries.special import RetryDDBChecksumError, crc32, logger + + +class AioRetryDDBChecksumError(RetryDDBChecksumError): + async def is_retryable(self, context): + service_name = context.operation_model.service_model.service_name + if service_name != self._SERVICE_NAME: + return False + if context.http_response is None: + return False + checksum = context.http_response.headers.get(self._CHECKSUM_HEADER) + if checksum is None: + return False + actual_crc32 = crc32(await context.http_response.content) & 0xffffffff + if actual_crc32 != int(checksum): + logger.debug("DynamoDB crc32 checksum does not match, " + "expected: %s, actual: %s", checksum, actual_crc32) + return True diff --git a/aiobotocore/retries/standard.py b/aiobotocore/retries/standard.py new file mode 100644 index 00000000..6de38073 --- /dev/null +++ b/aiobotocore/retries/standard.py @@ -0,0 +1,92 @@ +from botocore.retries.standard import RetryHandler, RetryPolicy, logger, \ + StandardRetryConditions, OrRetryChecker, DEFAULT_MAX_ATTEMPTS, RetryQuotaChecker, \ + quota, ExponentialBackoff, RetryEventAdapter, MaxAttemptsChecker, \ + TransientRetryableChecker, ThrottledRetryableChecker, ModeledRetryableChecker, \ + special + +from .._helpers import resolve_awaitable, async_any +from .special import AioRetryDDBChecksumError + + +def register_retry_handler(client, max_attempts=DEFAULT_MAX_ATTEMPTS): + retry_quota = RetryQuotaChecker(quota.RetryQuota()) + + service_id = client.meta.service_model.service_id + service_event_name = service_id.hyphenize() + client.meta.events.register('after-call.%s' % service_event_name, + retry_quota.release_retry_quota) + + handler = AioRetryHandler( + retry_policy=AioRetryPolicy( + retry_checker=AioStandardRetryConditions(max_attempts=max_attempts), + retry_backoff=ExponentialBackoff(), + ), + retry_event_adapter=RetryEventAdapter(), + retry_quota=retry_quota, + ) + + unique_id = 'retry-config-%s' % service_event_name + client.meta.events.register( + 'needs-retry.%s' % service_event_name, handler.needs_retry, + unique_id=unique_id + ) + return handler + + +class AioRetryHandler(RetryHandler): + async def needs_retry(self, **kwargs): + """Connect as a handler to the needs-retry event.""" + retry_delay = None + context = self._retry_event_adapter.create_retry_context(**kwargs) + if await self._retry_policy.should_retry(context): + # Before we can retry we need to ensure we have sufficient + # capacity in our retry quota. + if self._retry_quota.acquire_retry_quota(context): + retry_delay = self._retry_policy.compute_retry_delay(context) + logger.debug("Retry needed, retrying request after " + "delay of: %s", retry_delay) + else: + logger.debug("Retry needed but retry quota reached, " + "not retrying request.") + else: + logger.debug("Not retrying request.") + self._retry_event_adapter.adapt_retry_response_from_context( + context) + return retry_delay + + +class AioRetryPolicy(RetryPolicy): + async def should_retry(self, context): + return await resolve_awaitable(self._retry_checker.is_retryable(context)) + + +class AioStandardRetryConditions(StandardRetryConditions): + def __init__(self, max_attempts=DEFAULT_MAX_ATTEMPTS): # noqa: E501, lgtm [py/missing-call-to-init] + # Note: This class is for convenience so you can have the + # standard retry condition in a single class. + self._max_attempts_checker = MaxAttemptsChecker(max_attempts) + self._additional_checkers = AioOrRetryChecker([ + TransientRetryableChecker(), + ThrottledRetryableChecker(), + ModeledRetryableChecker(), + AioOrRetryChecker([ + special.RetryIDPCommunicationError(), + AioRetryDDBChecksumError(), + ]) + ]) + + async def is_retryable(self, context): + return ( + self._max_attempts_checker.is_retryable(context) + and await resolve_awaitable( + self._additional_checkers.is_retryable(context) + ) + ) + + +class AioOrRetryChecker(OrRetryChecker): + async def is_retryable(self, context): + return await async_any( + checker.is_retryable(context) + for checker in self._checkers + ) diff --git a/aiobotocore/retryhandler.py b/aiobotocore/retryhandler.py new file mode 100644 index 00000000..ef32f058 --- /dev/null +++ b/aiobotocore/retryhandler.py @@ -0,0 +1,183 @@ +from botocore.retryhandler import CRC32Checker, logger, crc32, ChecksumError, \ + create_retry_action_from_config, RetryHandler, MultiChecker, \ + MaxAttemptsDecorator, _extract_retryable_exception, \ + ServiceErrorCodeChecker, HTTPStatusCodeChecker, ExceptionRaiser + +from ._helpers import resolve_awaitable + + +def create_retry_handler(config, operation_name=None): + checker = create_checker_from_retry_config( + config, operation_name=operation_name) + action = create_retry_action_from_config( + config, operation_name=operation_name) + return AioRetryHandler(checker=checker, action=action) + + +def create_checker_from_retry_config(config, operation_name=None): + checkers = [] + max_attempts = None + retryable_exceptions = [] + if '__default__' in config: + policies = config['__default__'].get('policies', []) + max_attempts = config['__default__']['max_attempts'] + for key in policies: + current_config = policies[key] + checkers.append(_create_single_checker(current_config)) + retry_exception = _extract_retryable_exception(current_config) + if retry_exception is not None: + retryable_exceptions.extend(retry_exception) + if operation_name is not None and config.get(operation_name) is not None: + operation_policies = config[operation_name]['policies'] + for key in operation_policies: + checkers.append(_create_single_checker(operation_policies[key])) + retry_exception = _extract_retryable_exception( + operation_policies[key]) + if retry_exception is not None: + retryable_exceptions.extend(retry_exception) + if len(checkers) == 1: + # Don't need to use a MultiChecker + return AioMaxAttemptsDecorator(checkers[0], max_attempts=max_attempts) + else: + multi_checker = AioMultiChecker(checkers) + return AioMaxAttemptsDecorator( + multi_checker, max_attempts=max_attempts, + retryable_exceptions=tuple(retryable_exceptions)) + + +def _create_single_checker(config): + if 'response' in config['applies_when']: + return _create_single_response_checker( + config['applies_when']['response']) + elif 'socket_errors' in config['applies_when']: + return ExceptionRaiser() + + +def _create_single_response_checker(response): + if 'service_error_code' in response: + checker = ServiceErrorCodeChecker( + status_code=response['http_status_code'], + error_code=response['service_error_code']) + elif 'http_status_code' in response: + checker = HTTPStatusCodeChecker( + status_code=response['http_status_code']) + elif 'crc32body' in response: + checker = AioCRC32Checker(header=response['crc32body']) + else: + # TODO: send a signal. + raise ValueError("Unknown retry policy") + return checker + + +class AioRetryHandler(RetryHandler): + async def _call(self, attempts, response, caught_exception, **kwargs): + """Handler for a retry. + + Intended to be hooked up to an event handler (hence the **kwargs), + this will process retries appropriately. + + """ + checker_kwargs = { + 'attempt_number': attempts, + 'response': response, + 'caught_exception': caught_exception + } + if isinstance(self._checker, MaxAttemptsDecorator): + retries_context = kwargs['request_dict']['context'].get('retries') + checker_kwargs.update({'retries_context': retries_context}) + + if await resolve_awaitable(self._checker(**checker_kwargs)): + result = self._action(attempts=attempts) + logger.debug("Retry needed, action of: %s", result) + return result + logger.debug("No retry needed.") + + def __call__(self, *args, **kwargs): + return self._call(*args, **kwargs) # return awaitable + + +class AioMaxAttemptsDecorator(MaxAttemptsDecorator): + async def _call(self, attempt_number, response, caught_exception, + retries_context): + if retries_context: + retries_context['max'] = max( + retries_context.get('max', 0), self._max_attempts + ) + + should_retry = await self._should_retry(attempt_number, response, + caught_exception) + if should_retry: + if attempt_number >= self._max_attempts: + # explicitly set MaxAttemptsReached + if response is not None and 'ResponseMetadata' in response[1]: + response[1]['ResponseMetadata']['MaxAttemptsReached'] = True + logger.debug("Reached the maximum number of retry " + "attempts: %s", attempt_number) + return False + else: + return should_retry + else: + return False + + def __call__(self, *args, **kwargs): + return self._call(*args, **kwargs) + + async def _should_retry(self, attempt_number, response, caught_exception): + if self._retryable_exceptions and \ + attempt_number < self._max_attempts: + try: + return await resolve_awaitable( + self._checker(attempt_number, response, caught_exception)) + except self._retryable_exceptions as e: + logger.debug("retry needed, retryable exception caught: %s", + e, exc_info=True) + return True + else: + # If we've exceeded the max attempts we just let the exception + # propagate if one has occurred. + return await resolve_awaitable( + self._checker(attempt_number, response, caught_exception)) + + +class AioMultiChecker(MultiChecker): + async def _call(self, attempt_number, response, caught_exception): + for checker in self._checkers: + checker_response = await resolve_awaitable( + checker(attempt_number, response, + caught_exception)) + if checker_response: + return checker_response + return False + + def __call__(self, *args, **kwargs): + return self._call(*args, **kwargs) + + +class AioCRC32Checker(CRC32Checker): + async def _call(self, attempt_number, response, caught_exception): + if response is not None: + return await self._check_response(attempt_number, response) + elif caught_exception is not None: + return self._check_caught_exception( + attempt_number, caught_exception) + else: + raise ValueError("Both response and caught_exception are None.") + + def __call__(self, *args, **kwargs): + return self._call(*args, **kwargs) + + async def _check_response(self, attempt_number, response): + http_response = response[0] + expected_crc = http_response.headers.get(self._header_name) + if expected_crc is None: + logger.debug("crc32 check skipped, the %s header is not " + "in the http response.", self._header_name) + else: + actual_crc32 = crc32(await response[0].content) & 0xffffffff + if not actual_crc32 == int(expected_crc): + logger.debug( + "retry needed: crc32 check failed, expected != actual: " + "%s != %s", int(expected_crc), actual_crc32) + raise ChecksumError(checksum_type='crc32', + expected_checksum=int(expected_crc), + actual_checksum=actual_crc32) diff --git a/aiobotocore/session.py b/aiobotocore/session.py index 938b3218..0fbea6a2 100644 --- a/aiobotocore/session.py +++ b/aiobotocore/session.py @@ -1,36 +1,18 @@ from botocore.session import Session, EVENT_ALIASES, ServiceModel, \ UnknownServiceError, copy - from botocore import UNSIGNED -from botocore import retryhandler, translate +from botocore import translate from botocore.exceptions import PartialCredentialsError + +from . import retryhandler from .client import AioClientCreator, AioBaseClient from .hooks import AioHierarchicalEmitter from .parsers import AioResponseParserFactory -from .signers import add_generate_presigned_url, add_generate_presigned_post, \ - add_generate_db_auth_token -from .handlers import inject_presigned_url_ec2, inject_presigned_url_rds -from botocore.handlers import \ - inject_presigned_url_rds as boto_inject_presigned_url_rds, \ - inject_presigned_url_ec2 as boto_inject_presigned_url_ec2 -from botocore.signers import \ - add_generate_presigned_url as boto_add_generate_presigned_url, \ - add_generate_presigned_post as boto_add_generate_presigned_post, \ - add_generate_db_auth_token as boto_add_generate_db_auth_token from .configprovider import AioSmartDefaultsConfigStoreFactory from .credentials import create_credential_resolver, AioCredentials from .utils import AioIMDSRegionProvider -_HANDLER_MAPPING = { - boto_inject_presigned_url_ec2: inject_presigned_url_ec2, - boto_inject_presigned_url_rds: inject_presigned_url_rds, - boto_add_generate_presigned_url: add_generate_presigned_url, - boto_add_generate_presigned_post: add_generate_presigned_post, - boto_add_generate_db_auth_token: add_generate_db_auth_token, -} - - class ClientCreatorContext: def __init__(self, coro): self._coro = coro @@ -54,12 +36,6 @@ def __init__(self, session_vars=None, event_hooks=None, super().__init__(session_vars, event_hooks, include_builtin_handlers, profile) - def register(self, event_name, handler, unique_id=None, - unique_id_uses_count=False): - handler = _HANDLER_MAPPING.get(handler, handler) - - return super().register(event_name, handler, unique_id, unique_id_uses_count) - def _register_response_parser_factory(self): self._components.register_component('response_parser_factory', AioResponseParserFactory()) diff --git a/aiobotocore/utils.py b/aiobotocore/utils.py index 26810e74..250059b6 100644 --- a/aiobotocore/utils.py +++ b/aiobotocore/utils.py @@ -1,35 +1,76 @@ import asyncio import logging import json +from contextlib import asynccontextmanager +import inspect -import aiohttp import aiohttp.client_exceptions from botocore.utils import ContainerMetadataFetcher, InstanceMetadataFetcher, \ IMDSFetcher, get_environ_proxies, BadIMDSRequestError, S3RegionRedirector, \ ClientError, InstanceMetadataRegionFetcher, IMDSRegionProvider, \ - resolve_imds_endpoint_mode + resolve_imds_endpoint_mode, ReadTimeoutError, HTTPClientError, \ + DEFAULT_METADATA_SERVICE_TIMEOUT, METADATA_BASE_URL, os from botocore.exceptions import ( InvalidIMDSEndpointError, MetadataRetrievalError, ) import botocore.awsrequest +import aiobotocore.httpsession logger = logging.getLogger(__name__) RETRYABLE_HTTP_ERRORS = (aiohttp.client_exceptions.ClientError, asyncio.TimeoutError) +class _RefCountedSession(aiobotocore.httpsession.AIOHTTPSession): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.__ref_count = 0 + self.__lock = None + + @asynccontextmanager + async def acquire(self): + if not self.__lock: + self.__lock = asyncio.Lock() + + # ensure we have a session + async with self.__lock: + self.__ref_count += 1 + + try: + if self.__ref_count == 1: + await self.__aenter__() + except BaseException: + self.__ref_count -= 1 + raise + + try: + yield self + finally: + async with self.__lock: + if self.__ref_count == 1: + await self.__aexit__(None, None, None) + + self.__ref_count -= 1 + + class AioIMDSFetcher(IMDSFetcher): - class Response(object): - def __init__(self, status_code, text, url): - self.status_code = status_code - self.url = url - self.text = text - self.content = text - - def __init__(self, *args, session=None, **kwargs): - super(AioIMDSFetcher, self).__init__(*args, **kwargs) - self._trust_env = bool(get_environ_proxies(self._base_url)) - self._session = session or aiohttp.ClientSession + def __init__(self, timeout=DEFAULT_METADATA_SERVICE_TIMEOUT, # noqa: E501, lgtm [py/missing-call-to-init] + num_attempts=1, base_url=METADATA_BASE_URL, + env=None, user_agent=None, config=None, session=None): + self._timeout = timeout + self._num_attempts = num_attempts + self._base_url = self._select_base_url(base_url, config) + + if env is None: + env = os.environ.copy() + self._disabled = env.get('AWS_EC2_METADATA_DISABLED', 'false').lower() + self._disabled = self._disabled == 'true' + self._user_agent = user_agent + + self._session = session or _RefCountedSession( + timeout=self._timeout, + proxies=get_environ_proxies(self._base_url), + ) async def _fetch_metadata_token(self): self._assert_enabled() @@ -42,28 +83,26 @@ async def _fetch_metadata_token(self): request = botocore.awsrequest.AWSRequest( method='PUT', url=url, headers=headers) - timeout = aiohttp.ClientTimeout(total=self._timeout) - async with self._session(timeout=timeout, - trust_env=self._trust_env) as session: + async with self._session.acquire() as session: for i in range(self._num_attempts): try: - async with session.put(url, headers=headers) as resp: - text = await resp.text() - if resp.status == 200: - return text - elif resp.status in (404, 403, 405): - return None - elif resp.status in (400,): - raise BadIMDSRequestError(request) - except asyncio.TimeoutError: + response = await session.send(request.prepare()) + if response.status_code == 200: + return await response.text + elif response.status_code in (404, 403, 405): + return None + elif response.status_code in (400,): + raise BadIMDSRequestError(request) + except ReadTimeoutError: return None except RETRYABLE_HTTP_ERRORS as e: logger.debug( "Caught retryable HTTP exception while making metadata " "service request to %s: %s", url, e, exc_info=True) - except aiohttp.client_exceptions.ClientConnectorError as e: - if getattr(e, 'errno', None) == 8 or \ - str(getattr(e, 'os_error', None)) == \ + except HTTPClientError as e: + error = e.kwargs.get('error') + if error and getattr(error, 'errno', None) == 8 or \ + str(getattr(error, 'os_error', None)) == \ 'Domain name not found': # threaded vs async resolver raise InvalidIMDSEndpointError(endpoint=url, error=e) else: @@ -81,16 +120,17 @@ async def _get_request(self, url_path, retry_func, token=None): headers['x-aws-ec2-metadata-token'] = token self._add_user_agent(headers) - timeout = aiohttp.ClientTimeout(total=self._timeout) - async with self._session(timeout=timeout, - trust_env=self._trust_env) as session: + async with self._session.acquire() as session: for i in range(self._num_attempts): try: - async with session.get(url, headers=headers) as resp: - text = await resp.text() - response = self.Response(resp.status, text, resp.url) - - if not retry_func(response): + request = botocore.awsrequest.AWSRequest( + method='GET', url=url, headers=headers) + response = await session.send(request.prepare()) + should_retry = retry_func(response) + if inspect.isawaitable(should_retry): + should_retry = await should_retry + + if not should_retry: return response except RETRYABLE_HTTP_ERRORS as e: logger.debug( @@ -98,6 +138,37 @@ async def _get_request(self, url_path, retry_func, token=None): "service request to %s: %s", url, e, exc_info=True) raise self._RETRIES_EXCEEDED_ERROR_CLS() + async def _default_retry(self, response): + return ( + await self._is_non_ok_response(response) or + await self._is_empty(response) + ) + + async def _is_non_ok_response(self, response): + if response.status_code != 200: + await self._log_imds_response(response, 'non-200', log_body=True) + return True + return False + + async def _is_empty(self, response): + if not await response.content: + await self._log_imds_response(response, 'no body', log_body=True) + return True + return False + + async def _log_imds_response(self, response, reason_to_log, log_body=False): + statement = ( + "Metadata service returned %s response " + "with status code of %s for url: %s" + ) + logger_args = [ + reason_to_log, response.status_code, response.url + ] + if log_body: + statement += ", content body: %s" + logger_args.append(await response.content) + logger.debug(statement, *logger_args) + class AioInstanceMetadataFetcher(AioIMDSFetcher, InstanceMetadataFetcher): async def retrieve_iam_role_credentials(self): @@ -127,12 +198,11 @@ async def retrieve_iam_role_credentials(self): return {} async def _get_iam_role(self, token=None): - r = await self._get_request( + return await (await self._get_request( url_path=self._URL_PATH, retry_func=self._needs_retry_for_role_name, - token=token - ) - return r.text + token=token, + )).text async def _get_credentials(self, role_name, token=None): r = await self._get_request( @@ -140,7 +210,28 @@ async def _get_credentials(self, role_name, token=None): retry_func=self._needs_retry_for_credentials, token=token ) - return json.loads(r.text) + return json.loads(await r.text) + + async def _is_invalid_json(self, response): + try: + json.loads(await response.text) + return False + except ValueError: + await self._log_imds_response(response, 'invalid json') + return True + + async def _needs_retry_for_role_name(self, response): + return ( + await self._is_non_ok_response(response) or + await self._is_empty(response) + ) + + async def _needs_retry_for_credentials(self, response): + return ( + await self._is_non_ok_response(response) or + await self._is_empty(response) or + await self._is_invalid_json(response) + ) class AioIMDSRegionProvider(IMDSRegionProvider): @@ -194,7 +285,7 @@ async def _get_region(self): retry_func=self._default_retry, token=token ) - availability_zone = response.text + availability_zone = await response.text region = availability_zone[:-1] return region @@ -304,10 +395,13 @@ async def get_bucket_region(self, bucket, response): class AioContainerMetadataFetcher(ContainerMetadataFetcher): - def __init__(self, session=None, sleep=asyncio.sleep): + def __init__(self, session=None, sleep=asyncio.sleep): # noqa: E501, lgtm [py/missing-call-to-init] if session is None: - session = aiohttp.ClientSession - super(AioContainerMetadataFetcher, self).__init__(session, sleep) + session = _RefCountedSession( + timeout=self.TIMEOUT_SECONDS + ) + self._session = session + self._sleep = sleep async def retrieve_full_uri(self, full_url, headers=None): self._validate_allowed_url(full_url) @@ -344,25 +438,27 @@ async def _retrieve_credentials(self, full_url, extra_headers=None): async def _get_response(self, full_url, headers, timeout): try: - timeout = aiohttp.ClientTimeout(total=self.TIMEOUT_SECONDS) - async with self._session(timeout=timeout) as session: - async with session.get(full_url, headers=headers) as resp: - if resp.status != 200: - text = await resp.text() - raise MetadataRetrievalError( - error_msg=( - "Received non 200 response (%d) " - "from ECS metadata: %s" - ) % (resp.status, text)) - try: - return await resp.json() - except ValueError: - text = await resp.text() - error_msg = ( - "Unable to parse JSON returned from ECS metadata services" - ) - logger.debug('%s:%s', error_msg, text) - raise MetadataRetrievalError(error_msg=error_msg) + async with self._session.acquire() as session: + AWSRequest = botocore.awsrequest.AWSRequest + request = AWSRequest(method='GET', url=full_url, headers=headers) + response = await session.send(request.prepare()) + response_text = (await response.content).decode('utf-8') + + if response.status_code != 200: + raise MetadataRetrievalError( + error_msg=( + "Received non 200 response (%s) from ECS metadata: %s" + ) % (response.status_code, response_text) + ) + try: + return json.loads(response_text) + except ValueError: + error_msg = ( + "Unable to parse JSON returned from ECS metadata services" + ) + logger.debug('%s:%s', error_msg, response_text) + raise MetadataRetrievalError(error_msg=error_msg) + except RETRYABLE_HTTP_ERRORS as e: error_msg = ("Received error when attempting to retrieve " "ECS metadata: %s" % e) diff --git a/tests/botocore/test_credentials.py b/tests/botocore/test_credentials.py index 60e4bca8..050ab1ed 100644 --- a/tests/botocore/test_credentials.py +++ b/tests/botocore/test_credentials.py @@ -21,6 +21,7 @@ import pytest import botocore.exceptions +import wrapt from botocore.stub import Stubber from dateutil.tz import tzlocal, tzutc from botocore.utils import datetime2timestamp @@ -1330,6 +1331,14 @@ def _mock_provider(provider_cls): return mock_instance +class DummyContextWrapper(wrapt.ObjectProxy): + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + pass + + @asynccontextmanager async def _create_session(self, profile=None): session = StubbedSession(profile=profile) @@ -1341,7 +1350,7 @@ async def _create_session(self, profile=None): aws_access_key_id='spam', aws_secret_access_key='eggs', ) as sts: - self.mock_client_creator.return_value = sts + self.mock_client_creator.return_value = DummyContextWrapper(sts) assume_role_provider = AioAssumeRoleProvider( load_config=lambda: session.full_config, client_creator=self.mock_client_creator, diff --git a/tests/botocore/test_utils.py b/tests/botocore/test_utils.py index 91fa1a8b..f307f0f0 100644 --- a/tests/botocore/test_utils.py +++ b/tests/botocore/test_utils.py @@ -5,14 +5,16 @@ import itertools import unittest from typing import Union, List, Tuple +from contextlib import asynccontextmanager -from aiohttp.client_reqrep import ClientResponse, RequestInfo -from aiohttp.helpers import TimerNoop from aiohttp.client_exceptions import ClientConnectionError +from botocore.exceptions import ReadTimeoutError + from aiobotocore import utils from aiobotocore.utils import AioInstanceMetadataFetcher from botocore.utils import MetadataRetrievalError, BadIMDSRequestError -import yarl +from tests.test_response import AsyncBytesIO +from aiobotocore.awsrequest import AioAWSResponse # From class TestContainerMetadataFetcher @@ -27,20 +29,30 @@ def fake_aiohttp_session(responses: Union[List[Tuple[Union[str, object], int]], data = iter(responses) class FakeAioHttpSession(object): + @asynccontextmanager + async def acquire(self): + yield self + class FakeResponse(object): - def __init__(self, url, *args, **kwargs): - self.url = url - self._body, self.status = next(data) + def __init__(self, request, *args, **kwargs): + self.request = request + self.url = request.url + self._body, self.status_code = next(data) + self.content = self._content() + self.text = self._text() if not isinstance(self._body, str): raise self._body + async def _content(self): + return self._body.encode('utf-8') + async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_val, exc_tb): pass - async def text(self): + async def _text(self): return self._body async def json(self): @@ -55,13 +67,10 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_val, exc_tb): pass - def get(self, url, *args, **kwargs): - return self.FakeResponse(url) + async def send(self, request): + return self.FakeResponse(request) - def put(self, url, *args, **kwargs): - return self.FakeResponse(url) - - return FakeAioHttpSession + return FakeAioHttpSession() @pytest.mark.moto @@ -120,7 +129,7 @@ async def test_containermetadatafetcher_retrieve_url_not_json(): class TestInstanceMetadataFetcher(unittest.IsolatedAsyncioTestCase): async def asyncSetUp(self): - urllib3_session_send = 'aiohttp.ClientSession._request' + urllib3_session_send = 'aiobotocore.httpsession.AIOHTTPSession.send' self._urllib3_patch = mock.patch(urllib3_session_send) self._send = self._urllib3_patch.start() self._imds_responses = [] @@ -144,20 +153,13 @@ async def asyncTearDown(self): self._urllib3_patch.stop() def add_imds_response(self, body, status_code=200): - loop = asyncio.get_running_loop() - url = yarl.URL('http://169.254.169.254/') - method = 'get' - response = ClientResponse(method, url, - request_info=RequestInfo(url, method, {}), - writer=mock.AsyncMock(), - continue100=None, - timer=TimerNoop(), - traces=[], - loop=loop, - session=mock.AsyncMock()) - response.status = status_code - response._body = body - response._headers = {} + response = AioAWSResponse( + url='http://169.254.169.254/', + status_code=status_code, + headers={}, + raw=AsyncBytesIO(body) + ) + self._imds_responses.append(response) def add_get_role_name_imds_response(self, role_name=None): @@ -370,7 +372,7 @@ async def test_token_is_included(self): # Check that subsequent calls after getting the token include the token. self.assertEqual(self._send.call_count, 3) for call in self._send.call_args_list[1:]: - self.assertEqual(call.kwargs['headers']['x-aws-ec2-metadata-token'], + self.assertEqual(call[0][0].headers['x-aws-ec2-metadata-token'], 'token') self.assertEqual(result, self._expected_creds) @@ -386,7 +388,7 @@ async def test_metadata_token_not_supported_404(self): user_agent=user_agent).retrieve_iam_role_credentials() for call in self._send.call_args_list[1:]: - self.assertNotIn('x-aws-ec2-metadata-token', call.kwargs['headers']) + self.assertNotIn('x-aws-ec2-metadata-token', call[0][0].headers) self.assertEqual(result, self._expected_creds) @pytest.mark.moto @@ -401,7 +403,7 @@ async def test_metadata_token_not_supported_403(self): user_agent=user_agent).retrieve_iam_role_credentials() for call in self._send.call_args_list[1:]: - self.assertNotIn('x-aws-ec2-metadata-token', call.kwargs['headers']) + self.assertNotIn('x-aws-ec2-metadata-token', call[0][0].headers) self.assertEqual(result, self._expected_creds) @pytest.mark.moto @@ -416,7 +418,7 @@ async def test_metadata_token_not_supported_405(self): user_agent=user_agent).retrieve_iam_role_credentials() for call in self._send.call_args_list[1:]: - self.assertNotIn('x-aws-ec2-metadata-token', call.kwargs['headers']) + self.assertNotIn('x-aws-ec2-metadata-token', call[0][0].headers) self.assertEqual(result, self._expected_creds) @pytest.mark.moto @@ -431,7 +433,7 @@ async def test_metadata_token_not_supported_timeout(self): user_agent=user_agent).retrieve_iam_role_credentials() for call in self._send.call_args_list[1:]: - self.assertNotIn('x-aws-ec2-metadata-token', call.kwargs['headers']) + self.assertNotIn('x-aws-ec2-metadata-token', call[0][0].headers) self.assertEqual(result, self._expected_creds) @pytest.mark.moto @@ -446,7 +448,7 @@ async def test_token_not_supported_exhaust_retries(self): user_agent=user_agent).retrieve_iam_role_credentials() for call in self._send.call_args_list[1:]: - self.assertNotIn('x-aws-ec2-metadata-token', call.kwargs['headers']) + self.assertNotIn('x-aws-ec2-metadata-token', call[0][0].headers) self.assertEqual(result, self._expected_creds) @pytest.mark.moto @@ -515,7 +517,7 @@ async def test_idmsfetcher_get_token_bad_request(): @pytest.mark.asyncio async def test_idmsfetcher_get_token_timeout(): session = fake_aiohttp_session([ - (asyncio.TimeoutError(), 500), + (ReadTimeoutError(endpoint_url='aaa'), 500), ]) fetcher = utils.AioIMDSFetcher(num_attempts=2, @@ -554,7 +556,7 @@ async def test_idmsfetcher_retry(): user_agent='test') response = await fetcher._get_request('path', None, 'some_token') - assert response.text == 'data' + assert await response.text == 'data' session = fake_aiohttp_session([ ('blah', 500), diff --git a/tests/test_basic_s3.py b/tests/test_basic_s3.py index 76869cfa..981fd4d0 100644 --- a/tests/test_basic_s3.py +++ b/tests/test_basic_s3.py @@ -43,7 +43,7 @@ async def test_can_make_request_no_verify(s3_client): @pytest.mark.asyncio async def test_fail_proxy_request(aa_fail_proxy_config, s3_client, monkeypatch): # based on test_can_make_request - with pytest.raises(httpsession.EndpointConnectionError): + with pytest.raises(httpsession.ProxyConnectionError): await s3_client.list_buckets() diff --git a/tests/test_patches.py b/tests/test_patches.py index 8bd21a15..00aa61be 100644 --- a/tests/test_patches.py +++ b/tests/test_patches.py @@ -1,12 +1,8 @@ -import asyncio import hashlib from dill.source import getsource from itertools import chain import pytest -from yarl import URL - -from aiobotocore.endpoint import ClientResponseProxy import aiohttp from aiohttp.client import ClientResponse @@ -40,11 +36,16 @@ CanonicalNameCredentialSourcer, BotoProvider, OriginalEC2Provider, \ create_credential_resolver, get_credentials, create_mfa_serial_refresher, \ AssumeRoleWithWebIdentityCredentialFetcher, SSOCredentialFetcher, SSOProvider -from botocore.handlers import inject_presigned_url_ec2, inject_presigned_url_rds +from botocore.handlers import inject_presigned_url_ec2, inject_presigned_url_rds, \ + parse_get_bucket_location, check_for_200_error, _looks_like_special_case_error from botocore.httpsession import URLLib3Session from botocore.discovery import EndpointDiscoveryManager, EndpointDiscoveryHandler -from botocore.retries.adaptive import ClientRateLimiter, register_retry_handler +from botocore.retries import adaptive, special from botocore.retries.bucket import TokenBucket +from botocore import retryhandler +from botocore.retries import standard +from botocore.awsrequest import AWSResponse +from botocore.httpchecksum import handle_checksum_body, _handle_bytes_response # This file ensures that our private patches will work going forward. If a @@ -95,6 +96,10 @@ ClientCreator._register_retries: {'16d3064142e5f9e45b0094bbfabf7be30183f255'}, ClientCreator._register_v2_adaptive_retries: {'665ecd77d36a5abedffb746d83a44bb0a64c660a'}, + ClientCreator._register_v2_standard_retries: + {'9ec4ff68599544b4f46067b3783287862d38fb50'}, + ClientCreator._register_legacy_retries: + {'7dbd1a9d045b3d4f5bf830664f17c7bc610ee3a3'}, BaseClient._make_api_call: {'6517c7ead41bf0c70f38bb70666bffd21835ed72'}, BaseClient._make_request: {'033a386f7d1025522bea7f2bbca85edc5c8aafd2'}, @@ -232,6 +237,9 @@ HierarchicalEmitter._emit: {'5d9a6b1aea1323667a9310e707a9f0a006f8f6e8'}, HierarchicalEmitter.emit_until_response: {'23670e04e0b09a9575c7533442bca1b2972ade82'}, + HierarchicalEmitter._verify_and_register: + {'aa14572fd9d42b83793d4a9d61c680e37761d762'}, + EventAliaser.emit_until_response: {'0d635bf7ae5022b1fdde891cd9a91cd4c449fd49'}, # paginate.py @@ -270,7 +278,6 @@ Session.get_service_data: {'e28f2de9ebaf13214f1606d33349dfa8e2555923'}, Session.get_service_model: {'1c8f93e6fb9913e859e43aea9bc2546edbea8365'}, Session.get_available_regions: {'bc455d24d98fbc112ff22325ebfd12a6773cb7d4'}, - Session.register: {'39791fd2cffcea480f81e77c7daf3974581d9291'}, Session._register_smart_defaults_factory: {'24ab10e4751ada800dde24d40d1d105be76a0a14'}, @@ -301,10 +308,14 @@ {'f5294f9f811cb3cc370e4824ca106269ea1f44f9'}, ContainerMetadataFetcher._get_response: {'7e5acdd2cf0167a047e3d5ee1439565a2f79f6a6'}, - # Overrided session and dealing with proxy support + IMDSFetcher.__init__: {'a0766a5ba7dde9c26f3c51eb38d73f8e6087d492'}, IMDSFetcher._get_request: {'d06ba6890b94c819e79e27ac819454b28f704535'}, IMDSFetcher._fetch_metadata_token: {'c162c832ec24082cd2054945382d8dc6a1ec5e7b'}, + IMDSFetcher._default_retry: {'d1fa834cedfc7a2bf9957ba528eed24f600f7ef6'}, + IMDSFetcher._is_non_ok_response: {'448b80545b1946ec44ff19ebca8d4993872a6281'}, + IMDSFetcher._is_empty: {'241b141c9c352a4ef72964f8399d46cbe9a5aebc'}, + IMDSFetcher._log_imds_response: {'f1e09ad248feb167f55b11bbae735ea0e2c7b446'}, InstanceMetadataFetcher.retrieve_iam_role_credentials: {'76737f6add82a1b9a0dc590cf10bfac0c7026a2e'}, @@ -312,6 +323,13 @@ {'80073d7adc9fb604bc6235af87241f5efc296ad7'}, InstanceMetadataFetcher._get_credentials: {'1a64f59a3ca70b83700bd14deeac25af14100d58'}, + InstanceMetadataFetcher._is_invalid_json: + {'97818b51182a2507c99876a40155adda0451dd82'}, + InstanceMetadataFetcher._needs_retry_for_role_name: + {'0f1034c9de5be2d79a584e1e057b8df5b39f4514'}, + InstanceMetadataFetcher._needs_retry_for_credentials: + {'977be4286b42916779ade4c20472ec3a6a26c90d'}, + S3RegionRedirector.redirect_from_error: {'f6f765431145a9bed8e73e6a3dbc7b0d6ae5f738'}, S3RegionRedirector.get_bucket_region: @@ -335,6 +353,9 @@ # handlers.py inject_presigned_url_rds: {'5a34e1666d84f6229c54a59bffb69d46e8117b3a'}, inject_presigned_url_ec2: {'37fad2d9c53ca4f1783e32799fa8f70930f44c23'}, + parse_get_bucket_location: {'dde31b9fe4447ed6eb9b8c26ab14cc2bd3ae2c64'}, + check_for_200_error: {'94005c964d034e68bb2f079e89c93115c1f11aad'}, + _looks_like_special_case_error: {'adcf7c6f77aa123bd94e96ef0beb4ba548e55086'}, # httpsession.py URLLib3Session: {'5adede4ba9d2a80e776bfeb71127656fafff91d7'}, @@ -349,13 +370,52 @@ # retries/adaptive.py # See comments in AsyncTokenBucket: we completely replace the ClientRateLimiter # implementation from botocore. - ClientRateLimiter: {'d4ba74b924cdccf705adeb89f2c1885b4d21ce02'}, - register_retry_handler: {'d662512878511e72d1202d880ae181be6a5f9d37'}, + adaptive.ClientRateLimiter: {'d4ba74b924cdccf705adeb89f2c1885b4d21ce02'}, + adaptive.register_retry_handler: {'d662512878511e72d1202d880ae181be6a5f9d37'}, + + # retries/standard.py + standard.register_retry_handler: {'8d464a753335ce7457c5eea73e80d9a224fe7f21'}, + standard.RetryHandler.needs_retry: {'2dfc4c2d2efcd5ca00ae84ccdca4ab070d831e22'}, + standard.RetryPolicy.should_retry: {'b30eadcb94dadcdb90a5810cdeb2e3a0bc0c74c9'}, + standard.StandardRetryConditions.__init__: + {'82f00342fb50a681e431f07e63623ab3f1e39577'}, + standard.StandardRetryConditions.is_retryable: + {'4d14d1713bc2806c24b6797b2ec395a29c9b0453'}, + standard.OrRetryChecker.is_retryable: {'5ef0b84b1ef3a49bc193d76a359dbd314682856b'}, + + # retries/special.py + special.RetryDDBChecksumError.is_retryable: + {'0769cca303874f8dce47dcc93980fa0841fbaab6'}, # retries/bucket.py # See comments in AsyncTokenBucket: we completely replace the TokenBucket # implementation from botocore. TokenBucket: {'9d543c15de1d582fe99a768fd6d8bde1ed8bb930'}, + + # awsresponse.py + AWSResponse.content: {'1d74998e3e0abe52b52c251a1eae4971e65b1053'}, + AWSResponse.text: {'a724100ba9f6d51b333b8fe470fac46376d5044a'}, + + # httpchecksum.py + handle_checksum_body: {'4b9aeef18d816563624c66c57126d1ffa6fe1993'}, + _handle_bytes_response: {'76f4f9d1da968dc6dbc24fd9f59b4b8ee86799f4'}, + + # retryhandler.py + retryhandler.create_retry_handler: {'fde9dfbc581f3d571f7bf9af1a966f0d28f6d89d'}, + retryhandler.create_checker_from_retry_config: + {'3022785da77b62e0df06f048da3bb627a2e59bd5'}, + retryhandler._create_single_checker: {'517aaf8efda4bfe851d8dc024513973de1c5ffde'}, + retryhandler._create_single_response_checker: + {'f55d841e5afa5ebac6b883edf74a9d656415474b'}, + retryhandler.RetryHandler.__call__: {'0ff14b0e97db0d553e8b94a357c11187ca31ea5a'}, + retryhandler.MaxAttemptsDecorator.__call__: + {'d04ae8ff3ab82940bd7a5ffcd2aa27bf45a4817a'}, + retryhandler.MaxAttemptsDecorator._should_retry: + {'33af9b4af06372dc2a7985d6cbbf8dfbaee4be2a'}, + retryhandler.MultiChecker.__call__: {'dae2cc32aae9fa0a527630db5c5d8db96d957633'}, + retryhandler.CRC32Checker.__call__: {'4f0b55948e05a9039dc0ba62c80eb341682b85ac'}, + retryhandler.CRC32Checker._check_response: + {'bc371df204ab7138e792b782e83473e6e9b7a620'}, } @@ -377,25 +437,18 @@ def test_patches(): success = True for obj, digests in chain(_AIOHTTP_DIGESTS.items(), _API_DIGESTS.items()): - digest = hashlib.sha1(getsource(obj).encode('utf-8')).hexdigest() + + try: + source = getsource(obj) + except TypeError: + obj = obj.fget + source = getsource(obj) + + digest = hashlib.sha1(source.encode('utf-8')).hexdigest() + if digest not in digests: print("Digest of {}:{} not found in: {}".format( obj.__qualname__, digest, digests)) success = False assert success - - -# NOTE: this doesn't require moto but needs to be marked to run with coverage -@pytest.mark.moto -@pytest.mark.asyncio -async def test_set_status_code(): - resp = ClientResponseProxy( - 'GET', URL('http://foo/bar'), - loop=asyncio.get_event_loop(), - writer=None, continue100=None, timer=None, - request_info=None, - traces=None, - session=None) - resp.status_code = 500 - assert resp.status_code == 500 diff --git a/tests/test_response.py b/tests/test_response.py index 698ca040..b824833f 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -17,7 +17,11 @@ async def assert_lines(line_iterator, expected_lines): class AsyncBytesIO(io.BytesIO): - async def read(self, amt): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.content = self + + async def read(self, amt=-1): if amt == -1: # aiohttp to regular response amt = None return super().read(amt)