Skip to content

Commit

Permalink
Fix parameter assignment in error handling function (#1616)
Browse files Browse the repository at this point in the history
  • Loading branch information
mdazam1942 authored Nov 22, 2023
1 parent 3e8b7b0 commit 02d45de
Show file tree
Hide file tree
Showing 7 changed files with 145 additions and 72 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ def set_error_code(json_data, return_obj, connector=None):
if error_code == ErrorMapper.DEFAULT_ERROR:
ErrorMapper.logger.debug("failed to map: " + str(json_data))

ErrorMapperBase.set_error_code(return_obj, error_code, connector)
ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector)
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ def set_error_code(json_data, return_obj, connector=None):
if error_code == ErrorMapper.DEFAULT_ERROR:
ErrorMapper.logger.debug("failed to map: " + str(json_data))

ErrorMapperBase.set_error_code(return_obj, error_code, connector)
ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector)
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ def set_error_code(json_data, return_obj, connector=None):
if error_code == ErrorMapper.DEFAULT_ERROR:
ErrorMapper.logger.debug("failed to map: " + str(json_data))

ErrorMapperBase.set_error_code(return_obj, error_code, connector)
ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector)
88 changes: 63 additions & 25 deletions stix_shifter_modules/crowdstrike/stix_transmission/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from .api_client import APIClient
from stix_shifter_utils.utils.error_response import ErrorResponder
from stix_shifter_utils.utils import logger
from requests.exceptions import ConnectionError


class QueryException(Exception):
pass


class Connector(BaseJsonSyncConnector):
Expand Down Expand Up @@ -31,45 +36,76 @@ def _handle_errors(self, response, return_obj):
"""
response_code = response.code
response_txt = response.read().decode('utf-8')
response_type = response.headers.get('Content-Type')
response_dict = {}

if 200 <= response_code < 300:
return_obj['success'] = True
return_obj['data'] = response_txt
return return_obj
elif ErrorResponder.is_plain_string(response_txt):
ErrorResponder.fill_error(return_obj, message=response_txt)
raise Exception(return_obj)
elif ErrorResponder.is_json_string(response_txt):
response_json = json.loads(response_txt)
ErrorResponder.fill_error(return_obj, response_json, ['reason'], connector=self.connector)
raise Exception(return_obj)
else:
raise Exception(return_obj)
elif response_code >= 400:
if response_type == 'application/json':
error_response = json.loads(response_txt)
response_dict['type'] = 'ValidationError'
response_dict['message'] = error_response['errors'][0]['message']
ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
raise QueryException(return_obj)
elif response_type == 'text/html':
error = ConnectionError(f'Error connecting the datasource: {response_txt}')
ErrorResponder.fill_error(return_obj, response_dict, error=error, connector=self.connector)
raise QueryException(return_obj)
else:
raise Exception(response_txt)


async def ping_connection(self):
response_txt = None
return_obj = {}
response_dict = {}
try:
response = await self.api_client.ping_box()
response_code = response.code
response_txt = response.read().decode('utf-8')
response_type = response.headers.get('Content-Type')
if 199 < response_code < 300:
return_obj['success'] = True
elif isinstance(json.loads(response_txt), dict):
response_error_ping = json.loads(response_txt)
response_dict = response_error_ping['errors'][0]
ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
elif response_code == 401:
if response_type == 'application/json':
error_response = json.loads(response_txt)
response_dict['type'] = 'AuthenticationError'
response_dict['message'] = error_response['errors'][0]['message']
self.logger.error('Error connecting the Crowdstrike datasource: ' + str(error_response))
ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
else:
raise Exception(response_txt)
elif response_code == 400:
if response_type == 'application/json':
error_response = json.loads(response_txt)
response_dict['type'] = 'ValidationError'
response_dict['message'] = error_response['errors'][0]['message']
self.logger.error('Error connecting the Crowdstrike datasource: ' + str(error_response))
ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
else:
raise Exception(response_txt)
else:
raise Exception(response_txt)
if response_type == 'application/json':
response_error_ping = json.loads(response_txt)
response_dict = response_error_ping['errors'][0]
ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
elif response_type == 'text/html':
error = ConnectionError(f'Error connecting the datasource: {response_txt}')
ErrorResponder.fill_error(return_obj, response_dict, error=error, connector=self.connector)
else:
raise Exception(response_txt)
except Exception as e:
if response_txt is not None:
ErrorResponder.fill_error(return_obj, message='unexpected exception', connector=self.connector)
self.logger.error('can not parse response: ' + str(response_txt))
ErrorResponder.fill_error(return_obj, message='unexpected exception: ' + str(response_txt), connector=self.connector)
self.logger.error('Can not parse response Crowdstrike error: ' + str(response_txt))
else:
raise e

return return_obj

async def send_info_request_and_handle_errors(self, ids_lst):
return_obj = dict()
response = await self.api_client.get_detections_info(ids_lst)
Expand Down Expand Up @@ -142,7 +178,6 @@ async def create_results_connection(self, query, offset, length):
:param offset: int,offset value
:param length: int,length value"""
result_limit = offset + length
response_txt = None
ids_obj = dict()
return_obj = dict()
table_event_data = []
Expand Down Expand Up @@ -195,10 +230,13 @@ async def create_results_connection(self, query, offset, length):
if not return_obj.get('success'):
return_obj['success'] = True
return return_obj

except QueryException as ex:
return ex.args[0]
except Exception as ex:
if response_txt is not None:
ErrorResponder.fill_error(return_obj, message='unexpected exception', connector=self.connector)
self.logger.error('can not parse response: ' + str(response_txt))
else:
raise ex
error_dict = {}
error_dict['type'] = 'AttributeError'
error_dict['message'] = 'Error while parsing API response: ' + str(ex)
ErrorResponder.fill_error(return_obj, error_dict, ['message'], connector=self.connector)
self.logger.error('Unexpected exception from Crowdstrike datasource: ' + str(ex))

return return_obj
36 changes: 36 additions & 0 deletions stix_shifter_modules/crowdstrike/stix_transmission/error_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from stix_shifter_utils.utils.error_mapper_base import ErrorMapperBase
from stix_shifter_utils.utils.error_response import ErrorCode
from stix_shifter_utils.utils import logger

error_mapping = {
"ConnectionError": ErrorCode.TRANSMISSION_REMOTE_SYSTEM_IS_UNAVAILABLE,
"AuthenticationError": ErrorCode.TRANSMISSION_AUTH_CREDENTIALS,
"ValidationError": ErrorCode.TRANSMISSION_QUERY_LOGICAL_ERROR,
"AttributeError": ErrorCode.TRANSMISSION_INVALID_PARAMETER,
}


class ErrorMapper:
"""
Set Error Code
"""
logger = logger.set_logger(__name__)
DEFAULT_ERROR = ErrorCode.TRANSMISSION_MODULE_DEFAULT_ERROR

@staticmethod
def set_error_code(json_data, return_obj, connector=None):
err_type = None
try:
err_type = json_data['type']
except KeyError:
pass

error_type = ErrorMapper.DEFAULT_ERROR

if err_type in error_mapping:
error_type = error_mapping.get(err_type)

if error_type == ErrorMapper.DEFAULT_ERROR:
ErrorMapper.logger.error("failed to map: %s", str(json_data))

ErrorMapperBase.set_error_code(return_obj, error_type, connector=connector)
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import json
import unittest
from unittest.mock import ANY
from unittest.mock import patch
from tests.utils.async_utils import get_mock_response
from stix_shifter_modules.crowdstrike.entry_point import EntryPoint
Expand All @@ -18,6 +16,8 @@
'host': 'api.crowdstrike.com'
}

headers = {'Content-Type': 'application/json'}

@patch('stix_shifter_modules.crowdstrike.stix_transmission.api_client.APIClient.get_detections_IDs', autospec=True)
class TestCrowdStrikeConnection(unittest.TestCase, object):

Expand Down Expand Up @@ -54,21 +54,22 @@ def test_create_query_connection(self, mock_api_client):

def test_no_results_response(self, mock_requests_response):
mocked_return_value = """
{"terms": ["process_name:notepad.exe"],
"results": [],
"elapsed": 0.01921701431274414,
"comprehensive_search": true,
"all_segments": true,
"total_results": 0,
"highlights": [],
"facets": {},
"tagged_pids": {"00000036-0000-0a02-01d4-97e70c22b346-0167c881d4b3": [{"name": "Default Investigation", "id": 1}, {"name": "Default Investigation", "id": 1}]},
"start": 0,
"incomplete_results": false,
"filtered": {}
}
"""
mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode())
{
"terms": ["process_name:notepad.exe"],
"results": [],
"elapsed": 0.01921701431274414,
"comprehensive_search": true,
"all_segments": true,
"total_results": 0,
"highlights": [],
"facets": {},
"tagged_pids": {"00000036-0000-0a02-01d4-97e70c22b346-0167c881d4b3": [{"name": "Default Investigation", "id": 1}, {"name": "Default Investigation", "id": 1}]},
"start": 0,
"incomplete_results": false,
"filtered": {}
}
"""
mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers)

entry_point = EntryPoint(connection, config)
query_expression = self._create_query_list("process_name:notepad.exe")[0]
Expand All @@ -80,30 +81,28 @@ def test_no_results_response(self, mock_requests_response):
assert 'data' in results_response
assert len(results_response['data']) == 0



def test_one_results_response(self, mock_requests_response):
mocked_return_value = """
{
"terms": [
"process_name:cmd.exe",
"start:[2019-01-22T00:00:00 TO *]"
],
"results": [],
"elapsed": 0.05147600173950195,
"comprehensive_search": true,
"all_segments": true,
"total_results": 1,
"highlights": [],
"facets": {},
"tagged_pids": {},
"start": 0,
"incomplete_results": false,
"filtered": {}
}
"""

mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode())
{
"terms": [
"process_name:cmd.exe",
"start:[2019-01-22T00:00:00 TO *]"
],
"results": [],
"elapsed": 0.05147600173950195,
"comprehensive_search": true,
"all_segments": true,
"total_results": 1,
"highlights": [],
"facets": {},
"tagged_pids": {},
"start": 0,
"incomplete_results": false,
"filtered": {}
}
"""

mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers)

entry_point = EntryPoint(connection, config)
query_expression = self._create_query_list("process_name:cmd.exe start:[2019-01-22 TO *]")[0]
Expand All @@ -115,10 +114,9 @@ def test_one_results_response(self, mock_requests_response):
assert 'data' in results_response
assert len(results_response['data']) == 0


def test_transmit_limit_and_sort(self, mock_requests_response):
mocked_return_value = '{"reason": "query_syntax_error"}'
mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode())
mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers)

entry_point = EntryPoint(connection, config)
query_expression = self._create_query_list("process_name:cmd.exe")[0]
Expand Down
7 changes: 4 additions & 3 deletions tests/utils/async_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@

def get_mock_response(status_code, content=None, return_type='str', response=None):
return RequestMockResponse(status_code, content, return_type, response)
def get_mock_response(status_code, content=None, return_type='str', response=None, headers=None):
return RequestMockResponse(status_code, content, return_type, response, headers)

def get_aws_mock_response(obj):
return AWSComposeMockResponse(obj)

class RequestMockResponse:
def __init__(self, status_code, content, return_type='str', response=None):
def __init__(self, status_code, content, return_type='str', response=None, headers=None):
self.code = status_code
self.headers = headers
self.content = content
self.response = response
self.object = response
Expand Down

0 comments on commit 02d45de

Please sign in to comment.