From 02d45de3721d6f0b01d9ec6f62ce099c42218dbe Mon Sep 17 00:00:00 2001 From: Md Azam Date: Wed, 22 Nov 2023 09:30:15 -0400 Subject: [PATCH] Fix parameter assignment in error handling function (#1616) --- .../stix_transmission/error_mapper.py | 2 +- .../stix_transmission/error_mapper.py | 2 +- .../stix_transmission/error_mapper.py | 2 +- .../stix_transmission/connector.py | 88 +++++++++++++------ .../stix_transmission/error_mapper.py | 36 ++++++++ .../stix_transmission/test_crowdstrike.py | 80 ++++++++--------- tests/utils/async_utils.py | 7 +- 7 files changed, 145 insertions(+), 72 deletions(-) create mode 100644 stix_shifter_modules/crowdstrike/stix_transmission/error_mapper.py diff --git a/stix_shifter_modules/arcsight/stix_transmission/error_mapper.py b/stix_shifter_modules/arcsight/stix_transmission/error_mapper.py index 05dfa712a..209c52843 100644 --- a/stix_shifter_modules/arcsight/stix_transmission/error_mapper.py +++ b/stix_shifter_modules/arcsight/stix_transmission/error_mapper.py @@ -45,4 +45,4 @@ def set_error_code(json_data, return_obj, connector=None): if error_code == ErrorMapper.DEFAULT_ERROR: ErrorMapper.logger.debug("failed to map: " + str(json_data)) - ErrorMapperBase.set_error_code(return_obj, error_code, connector) + ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector) diff --git a/stix_shifter_modules/aws_athena/stix_transmission/error_mapper.py b/stix_shifter_modules/aws_athena/stix_transmission/error_mapper.py index 83a08dc33..7ef001841 100644 --- a/stix_shifter_modules/aws_athena/stix_transmission/error_mapper.py +++ b/stix_shifter_modules/aws_athena/stix_transmission/error_mapper.py @@ -43,4 +43,4 @@ def set_error_code(json_data, return_obj, connector=None): if error_code == ErrorMapper.DEFAULT_ERROR: ErrorMapper.logger.debug("failed to map: " + str(json_data)) - ErrorMapperBase.set_error_code(return_obj, error_code, connector) + ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector) diff --git a/stix_shifter_modules/aws_cloud_watch_logs/stix_transmission/error_mapper.py b/stix_shifter_modules/aws_cloud_watch_logs/stix_transmission/error_mapper.py index 0ff1c75b4..7c73ce608 100644 --- a/stix_shifter_modules/aws_cloud_watch_logs/stix_transmission/error_mapper.py +++ b/stix_shifter_modules/aws_cloud_watch_logs/stix_transmission/error_mapper.py @@ -44,4 +44,4 @@ def set_error_code(json_data, return_obj, connector=None): if error_code == ErrorMapper.DEFAULT_ERROR: ErrorMapper.logger.debug("failed to map: " + str(json_data)) - ErrorMapperBase.set_error_code(return_obj, error_code, connector) + ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector) diff --git a/stix_shifter_modules/crowdstrike/stix_transmission/connector.py b/stix_shifter_modules/crowdstrike/stix_transmission/connector.py index 71347b596..b2b532358 100644 --- a/stix_shifter_modules/crowdstrike/stix_transmission/connector.py +++ b/stix_shifter_modules/crowdstrike/stix_transmission/connector.py @@ -3,6 +3,11 @@ from .api_client import APIClient from stix_shifter_utils.utils.error_response import ErrorResponder from stix_shifter_utils.utils import logger +from requests.exceptions import ConnectionError + + +class QueryException(Exception): + pass class Connector(BaseJsonSyncConnector): @@ -31,45 +36,76 @@ def _handle_errors(self, response, return_obj): """ response_code = response.code response_txt = response.read().decode('utf-8') + response_type = response.headers.get('Content-Type') + response_dict = {} if 200 <= response_code < 300: return_obj['success'] = True return_obj['data'] = response_txt return return_obj - elif ErrorResponder.is_plain_string(response_txt): - ErrorResponder.fill_error(return_obj, message=response_txt) - raise Exception(return_obj) - elif ErrorResponder.is_json_string(response_txt): - response_json = json.loads(response_txt) - ErrorResponder.fill_error(return_obj, response_json, ['reason'], connector=self.connector) - raise Exception(return_obj) - else: - raise Exception(return_obj) + elif response_code >= 400: + if response_type == 'application/json': + error_response = json.loads(response_txt) + response_dict['type'] = 'ValidationError' + response_dict['message'] = error_response['errors'][0]['message'] + ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector) + raise QueryException(return_obj) + elif response_type == 'text/html': + error = ConnectionError(f'Error connecting the datasource: {response_txt}') + ErrorResponder.fill_error(return_obj, response_dict, error=error, connector=self.connector) + raise QueryException(return_obj) + else: + raise Exception(response_txt) + async def ping_connection(self): response_txt = None return_obj = {} + response_dict = {} try: response = await self.api_client.ping_box() response_code = response.code response_txt = response.read().decode('utf-8') + response_type = response.headers.get('Content-Type') if 199 < response_code < 300: return_obj['success'] = True - elif isinstance(json.loads(response_txt), dict): - response_error_ping = json.loads(response_txt) - response_dict = response_error_ping['errors'][0] - ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector) + elif response_code == 401: + if response_type == 'application/json': + error_response = json.loads(response_txt) + response_dict['type'] = 'AuthenticationError' + response_dict['message'] = error_response['errors'][0]['message'] + self.logger.error('Error connecting the Crowdstrike datasource: ' + str(error_response)) + ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector) + else: + raise Exception(response_txt) + elif response_code == 400: + if response_type == 'application/json': + error_response = json.loads(response_txt) + response_dict['type'] = 'ValidationError' + response_dict['message'] = error_response['errors'][0]['message'] + self.logger.error('Error connecting the Crowdstrike datasource: ' + str(error_response)) + ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector) + else: + raise Exception(response_txt) else: - raise Exception(response_txt) + if response_type == 'application/json': + response_error_ping = json.loads(response_txt) + response_dict = response_error_ping['errors'][0] + ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector) + elif response_type == 'text/html': + error = ConnectionError(f'Error connecting the datasource: {response_txt}') + ErrorResponder.fill_error(return_obj, response_dict, error=error, connector=self.connector) + else: + raise Exception(response_txt) except Exception as e: if response_txt is not None: - ErrorResponder.fill_error(return_obj, message='unexpected exception', connector=self.connector) - self.logger.error('can not parse response: ' + str(response_txt)) + ErrorResponder.fill_error(return_obj, message='unexpected exception: ' + str(response_txt), connector=self.connector) + self.logger.error('Can not parse response Crowdstrike error: ' + str(response_txt)) else: raise e - + return return_obj - + async def send_info_request_and_handle_errors(self, ids_lst): return_obj = dict() response = await self.api_client.get_detections_info(ids_lst) @@ -142,7 +178,6 @@ async def create_results_connection(self, query, offset, length): :param offset: int,offset value :param length: int,length value""" result_limit = offset + length - response_txt = None ids_obj = dict() return_obj = dict() table_event_data = [] @@ -195,10 +230,13 @@ async def create_results_connection(self, query, offset, length): if not return_obj.get('success'): return_obj['success'] = True return return_obj - + except QueryException as ex: + return ex.args[0] except Exception as ex: - if response_txt is not None: - ErrorResponder.fill_error(return_obj, message='unexpected exception', connector=self.connector) - self.logger.error('can not parse response: ' + str(response_txt)) - else: - raise ex + error_dict = {} + error_dict['type'] = 'AttributeError' + error_dict['message'] = 'Error while parsing API response: ' + str(ex) + ErrorResponder.fill_error(return_obj, error_dict, ['message'], connector=self.connector) + self.logger.error('Unexpected exception from Crowdstrike datasource: ' + str(ex)) + + return return_obj diff --git a/stix_shifter_modules/crowdstrike/stix_transmission/error_mapper.py b/stix_shifter_modules/crowdstrike/stix_transmission/error_mapper.py new file mode 100644 index 000000000..703d3912d --- /dev/null +++ b/stix_shifter_modules/crowdstrike/stix_transmission/error_mapper.py @@ -0,0 +1,36 @@ +from stix_shifter_utils.utils.error_mapper_base import ErrorMapperBase +from stix_shifter_utils.utils.error_response import ErrorCode +from stix_shifter_utils.utils import logger + +error_mapping = { + "ConnectionError": ErrorCode.TRANSMISSION_REMOTE_SYSTEM_IS_UNAVAILABLE, + "AuthenticationError": ErrorCode.TRANSMISSION_AUTH_CREDENTIALS, + "ValidationError": ErrorCode.TRANSMISSION_QUERY_LOGICAL_ERROR, + "AttributeError": ErrorCode.TRANSMISSION_INVALID_PARAMETER, +} + + +class ErrorMapper: + """ + Set Error Code + """ + logger = logger.set_logger(__name__) + DEFAULT_ERROR = ErrorCode.TRANSMISSION_MODULE_DEFAULT_ERROR + + @staticmethod + def set_error_code(json_data, return_obj, connector=None): + err_type = None + try: + err_type = json_data['type'] + except KeyError: + pass + + error_type = ErrorMapper.DEFAULT_ERROR + + if err_type in error_mapping: + error_type = error_mapping.get(err_type) + + if error_type == ErrorMapper.DEFAULT_ERROR: + ErrorMapper.logger.error("failed to map: %s", str(json_data)) + + ErrorMapperBase.set_error_code(return_obj, error_type, connector=connector) diff --git a/stix_shifter_modules/crowdstrike/tests/stix_transmission/test_crowdstrike.py b/stix_shifter_modules/crowdstrike/tests/stix_transmission/test_crowdstrike.py index 6503f799a..b582bc8e0 100644 --- a/stix_shifter_modules/crowdstrike/tests/stix_transmission/test_crowdstrike.py +++ b/stix_shifter_modules/crowdstrike/tests/stix_transmission/test_crowdstrike.py @@ -1,6 +1,4 @@ -import json import unittest -from unittest.mock import ANY from unittest.mock import patch from tests.utils.async_utils import get_mock_response from stix_shifter_modules.crowdstrike.entry_point import EntryPoint @@ -18,6 +16,8 @@ 'host': 'api.crowdstrike.com' } +headers = {'Content-Type': 'application/json'} + @patch('stix_shifter_modules.crowdstrike.stix_transmission.api_client.APIClient.get_detections_IDs', autospec=True) class TestCrowdStrikeConnection(unittest.TestCase, object): @@ -54,21 +54,22 @@ def test_create_query_connection(self, mock_api_client): def test_no_results_response(self, mock_requests_response): mocked_return_value = """ -{"terms": ["process_name:notepad.exe"], - "results": [], - "elapsed": 0.01921701431274414, - "comprehensive_search": true, - "all_segments": true, - "total_results": 0, - "highlights": [], - "facets": {}, - "tagged_pids": {"00000036-0000-0a02-01d4-97e70c22b346-0167c881d4b3": [{"name": "Default Investigation", "id": 1}, {"name": "Default Investigation", "id": 1}]}, - "start": 0, - "incomplete_results": false, - "filtered": {} -} -""" - mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode()) + { + "terms": ["process_name:notepad.exe"], + "results": [], + "elapsed": 0.01921701431274414, + "comprehensive_search": true, + "all_segments": true, + "total_results": 0, + "highlights": [], + "facets": {}, + "tagged_pids": {"00000036-0000-0a02-01d4-97e70c22b346-0167c881d4b3": [{"name": "Default Investigation", "id": 1}, {"name": "Default Investigation", "id": 1}]}, + "start": 0, + "incomplete_results": false, + "filtered": {} + } + """ + mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers) entry_point = EntryPoint(connection, config) query_expression = self._create_query_list("process_name:notepad.exe")[0] @@ -80,30 +81,28 @@ def test_no_results_response(self, mock_requests_response): assert 'data' in results_response assert len(results_response['data']) == 0 - - def test_one_results_response(self, mock_requests_response): mocked_return_value = """ -{ - "terms": [ - "process_name:cmd.exe", - "start:[2019-01-22T00:00:00 TO *]" - ], - "results": [], - "elapsed": 0.05147600173950195, - "comprehensive_search": true, - "all_segments": true, - "total_results": 1, - "highlights": [], - "facets": {}, - "tagged_pids": {}, - "start": 0, - "incomplete_results": false, - "filtered": {} -} -""" - - mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode()) + { + "terms": [ + "process_name:cmd.exe", + "start:[2019-01-22T00:00:00 TO *]" + ], + "results": [], + "elapsed": 0.05147600173950195, + "comprehensive_search": true, + "all_segments": true, + "total_results": 1, + "highlights": [], + "facets": {}, + "tagged_pids": {}, + "start": 0, + "incomplete_results": false, + "filtered": {} + } + """ + + mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers) entry_point = EntryPoint(connection, config) query_expression = self._create_query_list("process_name:cmd.exe start:[2019-01-22 TO *]")[0] @@ -115,10 +114,9 @@ def test_one_results_response(self, mock_requests_response): assert 'data' in results_response assert len(results_response['data']) == 0 - def test_transmit_limit_and_sort(self, mock_requests_response): mocked_return_value = '{"reason": "query_syntax_error"}' - mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode()) + mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers) entry_point = EntryPoint(connection, config) query_expression = self._create_query_list("process_name:cmd.exe")[0] diff --git a/tests/utils/async_utils.py b/tests/utils/async_utils.py index 61ae5c7f3..aafef083f 100644 --- a/tests/utils/async_utils.py +++ b/tests/utils/async_utils.py @@ -1,13 +1,14 @@ -def get_mock_response(status_code, content=None, return_type='str', response=None): - return RequestMockResponse(status_code, content, return_type, response) +def get_mock_response(status_code, content=None, return_type='str', response=None, headers=None): + return RequestMockResponse(status_code, content, return_type, response, headers) def get_aws_mock_response(obj): return AWSComposeMockResponse(obj) class RequestMockResponse: - def __init__(self, status_code, content, return_type='str', response=None): + def __init__(self, status_code, content, return_type='str', response=None, headers=None): self.code = status_code + self.headers = headers self.content = content self.response = response self.object = response