From 6e39d735329e5dde2896ca779267c6bb061245da Mon Sep 17 00:00:00 2001 From: Mark Chappell Date: Thu, 18 Apr 2024 13:36:53 +0200 Subject: [PATCH] S3ErrorHandler --- plugins/action/s3_object.py | 4 +- plugins/module_utils/s3.py | 87 ++++++ tests/unit/module_utils/s3/test_endpoints.py | 92 ++++++ .../{test_s3.py => s3/test_etags.py} | 106 +------ .../module_utils/s3/test_s3_error_handler.py | 263 ++++++++++++++++++ .../s3/test_validate_bucket_name.py | 34 +++ 6 files changed, 479 insertions(+), 107 deletions(-) create mode 100644 tests/unit/module_utils/s3/test_endpoints.py rename tests/unit/module_utils/{test_s3.py => s3/test_etags.py} (66%) create mode 100644 tests/unit/module_utils/s3/test_s3_error_handler.py create mode 100644 tests/unit/module_utils/s3/test_validate_bucket_name.py diff --git a/plugins/action/s3_object.py b/plugins/action/s3_object.py index f78a42fa39b..1faa07c0abb 100644 --- a/plugins/action/s3_object.py +++ b/plugins/action/s3_object.py @@ -48,7 +48,7 @@ def run(self, tmp=None, task_vars=None): # module handles error message for nonexistent files new_module_args["src"] = source except AnsibleError as e: - raise AnsibleActionFail(to_text(e)) + raise AnsibleActionFail(to_text(e)) from e wrap_async = self._task.async_val and not self._connection.has_native_async # execute the s3_object module with the updated args @@ -58,7 +58,7 @@ def run(self, tmp=None, task_vars=None): if not wrap_async: # remove a temporary path we created - self._remove_tmp_path(self._connection._shell.tmpdir) + self._remove_tmp_path(None) except AnsibleAction as e: result.update(e.result) diff --git a/plugins/module_utils/s3.py b/plugins/module_utils/s3.py index 961f36f22f0..8cbce2e288a 100644 --- a/plugins/module_utils/s3.py +++ b/plugins/module_utils/s3.py @@ -3,6 +3,7 @@ # Copyright (c) 2018 Red Hat, Inc. # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) +import functools import string from urllib.parse import urlparse @@ -14,6 +15,9 @@ HAS_MD5 = False try: + # Beware, S3 is a "special" case, it sometimes catches botocore exceptions and + # re-raises them as boto3 exceptions. + import boto3 import botocore except ImportError: pass # Handled by the calling module @@ -21,6 +25,89 @@ from ansible.module_utils.basic import to_text +from .botocore import is_boto3_error_code +from .botocore import is_boto3_error_message +from .errors import AWSErrorHandler +from .exceptions import AnsibleAWSError +from .retries import AWSRetry + +IGNORE_S3_DROP_IN_EXCEPTIONS = ["XNotImplemented", "NotImplemented", "AccessControlListNotSupported"] + + +class AnsibleS3Error(AnsibleAWSError): + pass + + +class AnsibleS3Sigv4RequiredError(AnsibleS3Error): + pass + + +class AnsibleS3PermissionsError(AnsibleS3Error): + pass + + +class AnsibleS3SupportError(AnsibleS3Error): + pass + + +class S3ErrorHandler(AWSErrorHandler): + _CUSTOM_EXCEPTION = AnsibleS3Error + + @classmethod + def _is_missing(cls): + return is_boto3_error_code( + [ + "404", + "NoSuchTagSet", + "NoSuchTagSetError", + "ObjectLockConfigurationNotFoundError", + "NoSuchBucketPolicy", + "ServerSideEncryptionConfigurationNotFoundError", + "NoSuchBucket", + "NoSuchPublicAccessBlockConfiguration", + "OwnershipControlsNotFoundError", + "NoSuchOwnershipControls", + ] + ) + + @classmethod + def common_error_handler(cls, description): + def wrapper(func): + @super(S3ErrorHandler, cls).common_error_handler(description) + @functools.wraps(func) + def handler(*args, **kwargs): + try: + return func(*args, **kwargs) + except is_boto3_error_code(["403", "AccessDenied"]) as e: + # FUTURE: there's a case to be made that this moves up into AWSErrorHandler + # for now, we'll handle this just for S3, but wait and see if it pops up in too + # many other places + raise AnsibleS3PermissionsError( + message=f"Failed to {description} (permission denied)", exception=e + ) from e + except is_boto3_error_message( # pylint: disable=duplicate-except + "require AWS Signature Version 4" + ) as e: + raise AnsibleS3Sigv4RequiredError( + message=f"Failed to {description} (not supported by cloud)", exception=e + ) from e + except is_boto3_error_code(IGNORE_S3_DROP_IN_EXCEPTIONS) as e: # pylint: disable=duplicate-except + # Unlike most of our modules, we attempt to handle non-AWS clouds. For read-only + # actions we sometimes need the ability to ignore unsupported features. + raise AnsibleS3SupportError( + message=f"Failed to {description} (not supported by cloud)", exception=e + ) from e + except botocore.exceptions.EndpointConnectionError as e: + raise cls._CUSTOM_EXCEPTION( + message=f"Failed to {description} - Invalid endpoint provided", exception=e + ) from e + except boto3.exceptions.Boto3Error as e: + raise cls._CUSTOM_EXCEPTION(message=f"Failed to {description}", exception=e) from e + + return handler + + return wrapper + def s3_head_objects(client, parts, bucket, obj, versionId): args = {"Bucket": bucket, "Key": obj} diff --git a/tests/unit/module_utils/s3/test_endpoints.py b/tests/unit/module_utils/s3/test_endpoints.py new file mode 100644 index 00000000000..98d46958658 --- /dev/null +++ b/tests/unit/module_utils/s3/test_endpoints.py @@ -0,0 +1,92 @@ +# -*- coding: utf-8 -*- +# +# (c) 2021 Red Hat Inc. +# +# This file is part of Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +from unittest.mock import patch + +import pytest + +from ansible_collections.amazon.aws.plugins.module_utils import s3 + +mod_urlparse = "ansible_collections.amazon.aws.plugins.module_utils.s3.urlparse" + + +class UrlInfo: + def __init__(self, scheme=None, hostname=None, port=None): + self.hostname = hostname + self.scheme = scheme + self.port = port + + +@patch(mod_urlparse) +def test_is_fakes3_with_none_arg(m_urlparse): + m_urlparse.side_effect = SystemExit(1) + result = s3.is_fakes3(None) + assert not result + m_urlparse.assert_not_called() + + +@pytest.mark.parametrize( + "url,scheme,result", + [ + ("https://test-s3.amazon.com", "https", False), + ("fakes3://test-s3.amazon.com", "fakes3", True), + ("fakes3s://test-s3.amazon.com", "fakes3s", True), + ], +) +@patch(mod_urlparse) +def test_is_fakes3(m_urlparse, url, scheme, result): + m_urlparse.return_value = UrlInfo(scheme=scheme) + assert result == s3.is_fakes3(url) + m_urlparse.assert_called_with(url) + + +@pytest.mark.parametrize( + "url,urlinfo,endpoint", + [ + ( + "fakes3://test-s3.amazon.com", + {"scheme": "fakes3", "hostname": "test-s3.amazon.com"}, + {"endpoint": "http://test-s3.amazon.com:80", "use_ssl": False}, + ), + ( + "fakes3://test-s3.amazon.com:8080", + {"scheme": "fakes3", "hostname": "test-s3.amazon.com", "port": 8080}, + {"endpoint": "http://test-s3.amazon.com:8080", "use_ssl": False}, + ), + ( + "fakes3s://test-s3.amazon.com", + {"scheme": "fakes3s", "hostname": "test-s3.amazon.com"}, + {"endpoint": "https://test-s3.amazon.com:443", "use_ssl": True}, + ), + ( + "fakes3s://test-s3.amazon.com:9096", + {"scheme": "fakes3s", "hostname": "test-s3.amazon.com", "port": 9096}, + {"endpoint": "https://test-s3.amazon.com:9096", "use_ssl": True}, + ), + ], +) +@patch(mod_urlparse) +def test_parse_fakes3_endpoint(m_urlparse, url, urlinfo, endpoint): + m_urlparse.return_value = UrlInfo(**urlinfo) + result = s3.parse_fakes3_endpoint(url) + assert endpoint == result + m_urlparse.assert_called_with(url) + + +@pytest.mark.parametrize( + "url,scheme,use_ssl", + [ + ("https://test-s3-ceph.amazon.com", "https", True), + ("http://test-s3-ceph.amazon.com", "http", False), + ], +) +@patch(mod_urlparse) +def test_parse_ceph_endpoint(m_urlparse, url, scheme, use_ssl): + m_urlparse.return_value = UrlInfo(scheme=scheme) + result = s3.parse_ceph_endpoint(url) + assert result == {"endpoint": url, "use_ssl": use_ssl} + m_urlparse.assert_called_with(url) diff --git a/tests/unit/module_utils/test_s3.py b/tests/unit/module_utils/s3/test_etags.py similarity index 66% rename from tests/unit/module_utils/test_s3.py rename to tests/unit/module_utils/s3/test_etags.py index 3770064c5b8..b775b3b2bb0 100644 --- a/tests/unit/module_utils/test_s3.py +++ b/tests/unit/module_utils/s3/test_etags.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- # # (c) 2021 Red Hat Inc. # @@ -188,108 +189,3 @@ def test_calculate_etag_failure(m_checksum_file, m_checksum_content, using_file) with pytest.raises(SystemExit): test_method(module, content, etag, client, s3bucket_name, s3bucket_object, version) module.fail_json_aws.assert_called() - - -@pytest.mark.parametrize( - "bucket_name,result", - [ - ("docexamplebucket1", None), - ("log-delivery-march-2020", None), - ("my-hosted-content", None), - ("docexamplewebsite.com", None), - ("www.docexamplewebsite.com", None), - ("my.example.s3.bucket", None), - ("doc", None), - ("doc_example_bucket", "invalid character(s) found in the bucket name"), - ("DocExampleBucket", "invalid character(s) found in the bucket name"), - ("doc-example-bucket-", "bucket names must begin and end with a letter or number"), - ( - "this.string.has.more.than.63.characters.so.it.should.not.passed.the.validated", - "the length of an S3 bucket cannot exceed 63 characters", - ), - ("my", "the length of an S3 bucket must be at least 3 characters"), - ], -) -def test_validate_bucket_name(bucket_name, result): - assert result == s3.validate_bucket_name(bucket_name) - - -mod_urlparse = "ansible_collections.amazon.aws.plugins.module_utils.s3.urlparse" - - -class UrlInfo: - def __init__(self, scheme=None, hostname=None, port=None): - self.hostname = hostname - self.scheme = scheme - self.port = port - - -@patch(mod_urlparse) -def test_is_fakes3_with_none_arg(m_urlparse): - m_urlparse.side_effect = SystemExit(1) - result = s3.is_fakes3(None) - assert not result - m_urlparse.assert_not_called() - - -@pytest.mark.parametrize( - "url,scheme,result", - [ - ("https://test-s3.amazon.com", "https", False), - ("fakes3://test-s3.amazon.com", "fakes3", True), - ("fakes3s://test-s3.amazon.com", "fakes3s", True), - ], -) -@patch(mod_urlparse) -def test_is_fakes3(m_urlparse, url, scheme, result): - m_urlparse.return_value = UrlInfo(scheme=scheme) - assert result == s3.is_fakes3(url) - m_urlparse.assert_called_with(url) - - -@pytest.mark.parametrize( - "url,urlinfo,endpoint", - [ - ( - "fakes3://test-s3.amazon.com", - {"scheme": "fakes3", "hostname": "test-s3.amazon.com"}, - {"endpoint": "http://test-s3.amazon.com:80", "use_ssl": False}, - ), - ( - "fakes3://test-s3.amazon.com:8080", - {"scheme": "fakes3", "hostname": "test-s3.amazon.com", "port": 8080}, - {"endpoint": "http://test-s3.amazon.com:8080", "use_ssl": False}, - ), - ( - "fakes3s://test-s3.amazon.com", - {"scheme": "fakes3s", "hostname": "test-s3.amazon.com"}, - {"endpoint": "https://test-s3.amazon.com:443", "use_ssl": True}, - ), - ( - "fakes3s://test-s3.amazon.com:9096", - {"scheme": "fakes3s", "hostname": "test-s3.amazon.com", "port": 9096}, - {"endpoint": "https://test-s3.amazon.com:9096", "use_ssl": True}, - ), - ], -) -@patch(mod_urlparse) -def test_parse_fakes3_endpoint(m_urlparse, url, urlinfo, endpoint): - m_urlparse.return_value = UrlInfo(**urlinfo) - result = s3.parse_fakes3_endpoint(url) - assert endpoint == result - m_urlparse.assert_called_with(url) - - -@pytest.mark.parametrize( - "url,scheme,use_ssl", - [ - ("https://test-s3-ceph.amazon.com", "https", True), - ("http://test-s3-ceph.amazon.com", "http", False), - ], -) -@patch(mod_urlparse) -def test_parse_ceph_endpoint(m_urlparse, url, scheme, use_ssl): - m_urlparse.return_value = UrlInfo(scheme=scheme) - result = s3.parse_ceph_endpoint(url) - assert result == {"endpoint": url, "use_ssl": use_ssl} - m_urlparse.assert_called_with(url) diff --git a/tests/unit/module_utils/s3/test_s3_error_handler.py b/tests/unit/module_utils/s3/test_s3_error_handler.py new file mode 100644 index 00000000000..d1dcc5b6755 --- /dev/null +++ b/tests/unit/module_utils/s3/test_s3_error_handler.py @@ -0,0 +1,263 @@ +# -*- coding: utf-8 -*- + +# Copyright: Contributors to the Ansible project +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +try: + import botocore +except ImportError: + pass + +import pytest + +from ansible_collections.amazon.aws.plugins.module_utils.botocore import HAS_BOTO3 +from ansible_collections.amazon.aws.plugins.module_utils.s3 import AnsibleS3Error +from ansible_collections.amazon.aws.plugins.module_utils.s3 import AnsibleS3PermissionsError +from ansible_collections.amazon.aws.plugins.module_utils.s3 import AnsibleS3Sigv4RequiredError +from ansible_collections.amazon.aws.plugins.module_utils.s3 import AnsibleS3SupportError +from ansible_collections.amazon.aws.plugins.module_utils.s3 import S3ErrorHandler + +if not HAS_BOTO3: + pytestmark = pytest.mark.skip("test_s3_error_handler.py requires the python modules 'boto3' and 'botocore'") + + +class TestS3DeletionHandler: + def test_no_failures(self): + self.counter = 0 + + @S3ErrorHandler.deletion_error_handler("no error") + def no_failures(): + self.counter += 1 + + no_failures() + assert self.counter == 1 + + def test_client_error(self): + self.counter = 0 + err_response = { + "Error": { + "Code": "MalformedPolicyDocument", + "Message": "Policy document should not specify a principal", + } + } + + @S3ErrorHandler.deletion_error_handler("do something") + def raise_client_error(): + self.counter += 1 + raise botocore.exceptions.ClientError(err_response, "Something bad") + + with pytest.raises(AnsibleS3Error) as e_info: + raise_client_error() + assert self.counter == 1 + raised = e_info.value + assert isinstance(raised.exception, botocore.exceptions.ClientError) + assert "do something" in raised.message + assert "Something bad" in str(raised.exception) + + def test_ignore_error(self): + self.counter = 0 + err_response = { + "Error": { + "Code": "404", + "Message": "Not found", + } + } + + @S3ErrorHandler.deletion_error_handler("do something") + def raise_client_error(): + self.counter += 1 + raise botocore.exceptions.ClientError(err_response, "I couldn't find it") + + ret_val = raise_client_error() + assert self.counter == 1 + assert ret_val is False + + +class TestS3ListHandler: + def test_no_failures(self): + self.counter = 0 + + @S3ErrorHandler.list_error_handler("no error") + def no_failures(): + self.counter += 1 + + no_failures() + assert self.counter == 1 + + def test_client_error(self): + self.counter = 0 + err_response = { + "Error": { + "Code": "MalformedPolicyDocument", + "Message": "Policy document should not specify a principal.", + } + } + + @S3ErrorHandler.list_error_handler("do something") + def raise_client_error(): + self.counter += 1 + raise botocore.exceptions.ClientError(err_response, "Something bad") + + with pytest.raises(AnsibleS3Error) as e_info: + raise_client_error() + assert self.counter == 1 + raised = e_info.value + assert isinstance(raised.exception, botocore.exceptions.ClientError) + assert "do something" in raised.message + assert "Something bad" in str(raised.exception) + + def test_list_error(self): + self.counter = 0 + err_response = { + "Error": { + "Code": "404", + "Message": "Not found", + } + } + + @S3ErrorHandler.list_error_handler("do something") + def raise_client_error(): + self.counter += 1 + raise botocore.exceptions.ClientError(err_response, "I couldn't find it") + + ret_val = raise_client_error() + assert self.counter == 1 + assert ret_val is None + + +class TestS3CommonHandler: + def test_no_failures(self): + self.counter = 0 + + @S3ErrorHandler.common_error_handler("no error") + def no_failures(): + self.counter += 1 + + no_failures() + assert self.counter == 1 + + def test_client_error(self): + self.counter = 0 + err_response = { + "Error": { + "Code": "MalformedPolicyDocument", + "Message": "Policy document should not specify a principal.", + } + } + + @S3ErrorHandler.common_error_handler("do something") + def raise_client_error(): + self.counter += 1 + raise botocore.exceptions.ClientError(err_response, "Something bad") + + with pytest.raises(AnsibleS3Error) as e_info: + raise_client_error() + assert self.counter == 1 + raised = e_info.value + assert isinstance(raised.exception, botocore.exceptions.ClientError) + assert "do something" in raised.message + assert "Something bad" in str(raised.exception) + + def test_permissions_error(self): + self.counter = 0 + err_response = { + "Error": { + "Code": "AccessDenied", + "Message": "Forbidden", + } + } + + @S3ErrorHandler.common_error_handler("do something") + def raise_client_error(): + self.counter += 1 + raise botocore.exceptions.ClientError(err_response, "Something bad") + + with pytest.raises(AnsibleS3PermissionsError) as e_info: + raise_client_error() + assert self.counter == 1 + raised = e_info.value + assert isinstance(raised.exception, botocore.exceptions.ClientError) + assert "do something" in raised.message + assert "Something bad" in str(raised.exception) + + def test_not_implemented_error(self): + self.counter = 0 + err_response = { + "Error": { + "Code": "XNotImplemented", + "Message": "The request you provided implies functionality that is not implemented.", + } + } + + @S3ErrorHandler.common_error_handler("do something") + def raise_client_error(): + self.counter += 1 + raise botocore.exceptions.ClientError(err_response, "Something bad") + + with pytest.raises(AnsibleS3SupportError) as e_info: + raise_client_error() + assert self.counter == 1 + raised = e_info.value + assert isinstance(raised.exception, botocore.exceptions.ClientError) + assert "do something" in raised.message + assert "Something bad" in str(raised.exception) + + def test_endpoint_error(self): + self.counter = 0 + + @S3ErrorHandler.common_error_handler("do something") + def raise_connection_error(): + self.counter += 1 + raise botocore.exceptions.EndpointConnectionError(endpoint_url="junk.endpoint") + + with pytest.raises(AnsibleS3Error) as e_info: + raise_connection_error() + assert self.counter == 1 + raised = e_info.value + assert isinstance(raised.exception, botocore.exceptions.BotoCoreError) + assert "do something" in raised.message + assert "junk.endpoint" in str(raised.exception) + + def test_sigv4_error(self): + self.counter = 0 + err_response = { + "Error": { + "Code": "InvalidArgument", + "Message": "Requests specifying Server Side Encryption with AWS KMS managed keys require AWS Signature Version 4", + } + } + + @S3ErrorHandler.common_error_handler("do something") + def raise_client_error(): + self.counter += 1 + raise botocore.exceptions.ClientError(err_response, "Something bad") + + with pytest.raises(AnsibleS3Sigv4RequiredError) as e_info: + raise_client_error() + assert self.counter == 1 + raised = e_info.value + assert isinstance(raised.exception, botocore.exceptions.ClientError) + assert "do something" in raised.message + assert "Something bad" in str(raised.exception) + + def test_boto3_error(self): + self.counter = 0 + err_response = { + "Error": { + "Code": "MalformedPolicyDocument", + "Message": "Policy document should not specify a principal.", + } + } + + @S3ErrorHandler.common_error_handler("do something") + def raise_client_error(): + self.counter += 1 + raise botocore.exceptions.ClientError(err_response, "Something bad") + + with pytest.raises(AnsibleS3Error) as e_info: + raise_client_error() + assert self.counter == 1 + raised = e_info.value + assert isinstance(raised.exception, botocore.exceptions.ClientError) + assert "do something" in raised.message + assert "Something bad" in str(raised.exception) diff --git a/tests/unit/module_utils/s3/test_validate_bucket_name.py b/tests/unit/module_utils/s3/test_validate_bucket_name.py new file mode 100644 index 00000000000..45d935752d0 --- /dev/null +++ b/tests/unit/module_utils/s3/test_validate_bucket_name.py @@ -0,0 +1,34 @@ +# -*- coding: utf-8 -*- +# +# (c) 2021 Red Hat Inc. +# +# This file is part of Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + +import pytest + +from ansible_collections.amazon.aws.plugins.module_utils import s3 + + +@pytest.mark.parametrize( + "bucket_name,result", + [ + ("docexamplebucket1", None), + ("log-delivery-march-2020", None), + ("my-hosted-content", None), + ("docexamplewebsite.com", None), + ("www.docexamplewebsite.com", None), + ("my.example.s3.bucket", None), + ("doc", None), + ("doc_example_bucket", "invalid character(s) found in the bucket name"), + ("DocExampleBucket", "invalid character(s) found in the bucket name"), + ("doc-example-bucket-", "bucket names must begin and end with a letter or number"), + ( + "this.string.has.more.than.63.characters.so.it.should.not.passed.the.validated", + "the length of an S3 bucket cannot exceed 63 characters", + ), + ("my", "the length of an S3 bucket must be at least 3 characters"), + ], +) +def test_validate_bucket_name(bucket_name, result): + assert result == s3.validate_bucket_name(bucket_name)