Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix parameter assignment in error handling function #1616

Merged
merged 4 commits into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ def set_error_code(json_data, return_obj, connector=None):
if error_code == ErrorMapper.DEFAULT_ERROR:
ErrorMapper.logger.debug("failed to map: " + str(json_data))

ErrorMapperBase.set_error_code(return_obj, error_code, connector)
ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector)
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,4 @@ def set_error_code(json_data, return_obj, connector=None):
if error_code == ErrorMapper.DEFAULT_ERROR:
ErrorMapper.logger.debug("failed to map: " + str(json_data))

ErrorMapperBase.set_error_code(return_obj, error_code, connector)
ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector)
Original file line number Diff line number Diff line change
Expand Up @@ -44,4 +44,4 @@ def set_error_code(json_data, return_obj, connector=None):
if error_code == ErrorMapper.DEFAULT_ERROR:
ErrorMapper.logger.debug("failed to map: " + str(json_data))

ErrorMapperBase.set_error_code(return_obj, error_code, connector)
ErrorMapperBase.set_error_code(return_obj, error_code, connector=connector)
88 changes: 63 additions & 25 deletions stix_shifter_modules/crowdstrike/stix_transmission/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,11 @@
from .api_client import APIClient
from stix_shifter_utils.utils.error_response import ErrorResponder
from stix_shifter_utils.utils import logger
from requests.exceptions import ConnectionError


class QueryException(Exception):
pass


class Connector(BaseJsonSyncConnector):
Expand Down Expand Up @@ -31,45 +36,76 @@ def _handle_errors(self, response, return_obj):
"""
response_code = response.code
response_txt = response.read().decode('utf-8')
response_type = response.headers.get('Content-Type')
response_dict = {}

if 200 <= response_code < 300:
return_obj['success'] = True
return_obj['data'] = response_txt
return return_obj
elif ErrorResponder.is_plain_string(response_txt):
ErrorResponder.fill_error(return_obj, message=response_txt)
raise Exception(return_obj)
elif ErrorResponder.is_json_string(response_txt):
response_json = json.loads(response_txt)
ErrorResponder.fill_error(return_obj, response_json, ['reason'], connector=self.connector)
raise Exception(return_obj)
else:
raise Exception(return_obj)
elif response_code >= 400:
if response_type == 'application/json':
error_response = json.loads(response_txt)
response_dict['type'] = 'ValidationError'
response_dict['message'] = error_response['errors'][0]['message']
ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
raise QueryException(return_obj)
elif response_type == 'text/html':
error = ConnectionError(f'Error connecting the datasource: {response_txt}')
delliott90 marked this conversation as resolved.
Show resolved Hide resolved
ErrorResponder.fill_error(return_obj, response_dict, error=error, connector=self.connector)
raise QueryException(return_obj)
else:
raise Exception(response_txt)


async def ping_connection(self):
response_txt = None
return_obj = {}
response_dict = {}
try:
response = await self.api_client.ping_box()
response_code = response.code
response_txt = response.read().decode('utf-8')
response_type = response.headers.get('Content-Type')
if 199 < response_code < 300:
return_obj['success'] = True
elif isinstance(json.loads(response_txt), dict):
response_error_ping = json.loads(response_txt)
response_dict = response_error_ping['errors'][0]
ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
elif response_code == 401:
if response_type == 'application/json':
error_response = json.loads(response_txt)
response_dict['type'] = 'AuthenticationError'
response_dict['message'] = error_response['errors'][0]['message']
self.logger.error('Error connecting the Crowdstrike datasource: ' + str(error_response))
ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
else:
raise Exception(response_txt)
elif response_code == 400:
if response_type == 'application/json':
error_response = json.loads(response_txt)
response_dict['type'] = 'ValidationError'
response_dict['message'] = error_response['errors'][0]['message']
self.logger.error('Error connecting the Crowdstrike datasource: ' + str(error_response))
ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
else:
raise Exception(response_txt)
else:
raise Exception(response_txt)
if response_type == 'application/json':
response_error_ping = json.loads(response_txt)
response_dict = response_error_ping['errors'][0]
ErrorResponder.fill_error(return_obj, response_dict, ['message'], connector=self.connector)
elif response_type == 'text/html':
error = ConnectionError(f'Error connecting the datasource: {response_txt}')
delliott90 marked this conversation as resolved.
Show resolved Hide resolved
ErrorResponder.fill_error(return_obj, response_dict, error=error, connector=self.connector)
else:
raise Exception(response_txt)
except Exception as e:
if response_txt is not None:
ErrorResponder.fill_error(return_obj, message='unexpected exception', connector=self.connector)
self.logger.error('can not parse response: ' + str(response_txt))
ErrorResponder.fill_error(return_obj, message='unexpected exception: ' + str(response_txt), connector=self.connector)
delliott90 marked this conversation as resolved.
Show resolved Hide resolved
self.logger.error('Can not parse response Crowdstrike error: ' + str(response_txt))
else:
raise e

return return_obj

async def send_info_request_and_handle_errors(self, ids_lst):
return_obj = dict()
response = await self.api_client.get_detections_info(ids_lst)
Expand Down Expand Up @@ -142,7 +178,6 @@ async def create_results_connection(self, query, offset, length):
:param offset: int,offset value
:param length: int,length value"""
result_limit = offset + length
response_txt = None
ids_obj = dict()
return_obj = dict()
table_event_data = []
Expand Down Expand Up @@ -195,10 +230,13 @@ async def create_results_connection(self, query, offset, length):
if not return_obj.get('success'):
return_obj['success'] = True
return return_obj

except QueryException as ex:
return ex.args[0]
except Exception as ex:
if response_txt is not None:
ErrorResponder.fill_error(return_obj, message='unexpected exception', connector=self.connector)
self.logger.error('can not parse response: ' + str(response_txt))
else:
raise ex
error_dict = {}
error_dict['type'] = 'AttributeError'
error_dict['message'] = 'Error while parsing API response: ' + str(ex)
delliott90 marked this conversation as resolved.
Show resolved Hide resolved
ErrorResponder.fill_error(return_obj, error_dict, ['message'], connector=self.connector)
self.logger.error('Unexpected exception from Crowdstrike datasource: ' + str(ex))

return return_obj
36 changes: 36 additions & 0 deletions stix_shifter_modules/crowdstrike/stix_transmission/error_mapper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from stix_shifter_utils.utils.error_mapper_base import ErrorMapperBase
from stix_shifter_utils.utils.error_response import ErrorCode
from stix_shifter_utils.utils import logger

error_mapping = {
"ConnectionError": ErrorCode.TRANSMISSION_REMOTE_SYSTEM_IS_UNAVAILABLE,
"AuthenticationError": ErrorCode.TRANSMISSION_AUTH_CREDENTIALS,
"ValidationError": ErrorCode.TRANSMISSION_QUERY_LOGICAL_ERROR,
"AttributeError": ErrorCode.TRANSMISSION_INVALID_PARAMETER,
}


class ErrorMapper:
"""
Set Error Code
"""
logger = logger.set_logger(__name__)
DEFAULT_ERROR = ErrorCode.TRANSMISSION_MODULE_DEFAULT_ERROR

@staticmethod
def set_error_code(json_data, return_obj, connector=None):
err_type = None
try:
err_type = json_data['type']
except KeyError:
pass

error_type = ErrorMapper.DEFAULT_ERROR

if err_type in error_mapping:
error_type = error_mapping.get(err_type)

if error_type == ErrorMapper.DEFAULT_ERROR:
ErrorMapper.logger.error("failed to map: %s", str(json_data))

ErrorMapperBase.set_error_code(return_obj, error_type, connector=connector)
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
import json
import unittest
from unittest.mock import ANY
from unittest.mock import patch
from tests.utils.async_utils import get_mock_response
from stix_shifter_modules.crowdstrike.entry_point import EntryPoint
Expand All @@ -18,6 +16,8 @@
'host': 'api.crowdstrike.com'
}

headers = {'Content-Type': 'application/json'}

@patch('stix_shifter_modules.crowdstrike.stix_transmission.api_client.APIClient.get_detections_IDs', autospec=True)
class TestCrowdStrikeConnection(unittest.TestCase, object):

Expand Down Expand Up @@ -54,21 +54,22 @@ def test_create_query_connection(self, mock_api_client):

def test_no_results_response(self, mock_requests_response):
mocked_return_value = """
{"terms": ["process_name:notepad.exe"],
"results": [],
"elapsed": 0.01921701431274414,
"comprehensive_search": true,
"all_segments": true,
"total_results": 0,
"highlights": [],
"facets": {},
"tagged_pids": {"00000036-0000-0a02-01d4-97e70c22b346-0167c881d4b3": [{"name": "Default Investigation", "id": 1}, {"name": "Default Investigation", "id": 1}]},
"start": 0,
"incomplete_results": false,
"filtered": {}
}
"""
mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode())
{
"terms": ["process_name:notepad.exe"],
"results": [],
"elapsed": 0.01921701431274414,
"comprehensive_search": true,
"all_segments": true,
"total_results": 0,
"highlights": [],
"facets": {},
"tagged_pids": {"00000036-0000-0a02-01d4-97e70c22b346-0167c881d4b3": [{"name": "Default Investigation", "id": 1}, {"name": "Default Investigation", "id": 1}]},
"start": 0,
"incomplete_results": false,
"filtered": {}
}
"""
mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers)

entry_point = EntryPoint(connection, config)
query_expression = self._create_query_list("process_name:notepad.exe")[0]
Expand All @@ -80,30 +81,28 @@ def test_no_results_response(self, mock_requests_response):
assert 'data' in results_response
assert len(results_response['data']) == 0



def test_one_results_response(self, mock_requests_response):
mocked_return_value = """
{
"terms": [
"process_name:cmd.exe",
"start:[2019-01-22T00:00:00 TO *]"
],
"results": [],
"elapsed": 0.05147600173950195,
"comprehensive_search": true,
"all_segments": true,
"total_results": 1,
"highlights": [],
"facets": {},
"tagged_pids": {},
"start": 0,
"incomplete_results": false,
"filtered": {}
}
"""

mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode())
{
"terms": [
"process_name:cmd.exe",
"start:[2019-01-22T00:00:00 TO *]"
],
"results": [],
"elapsed": 0.05147600173950195,
"comprehensive_search": true,
"all_segments": true,
"total_results": 1,
"highlights": [],
"facets": {},
"tagged_pids": {},
"start": 0,
"incomplete_results": false,
"filtered": {}
}
"""

mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers)

entry_point = EntryPoint(connection, config)
query_expression = self._create_query_list("process_name:cmd.exe start:[2019-01-22 TO *]")[0]
Expand All @@ -115,10 +114,9 @@ def test_one_results_response(self, mock_requests_response):
assert 'data' in results_response
assert len(results_response['data']) == 0


def test_transmit_limit_and_sort(self, mock_requests_response):
mocked_return_value = '{"reason": "query_syntax_error"}'
mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode())
mock_requests_response.return_value = get_mock_response(200, mocked_return_value.encode(), headers=headers)

entry_point = EntryPoint(connection, config)
query_expression = self._create_query_list("process_name:cmd.exe")[0]
Expand Down
7 changes: 4 additions & 3 deletions tests/utils/async_utils.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@

def get_mock_response(status_code, content=None, return_type='str', response=None):
return RequestMockResponse(status_code, content, return_type, response)
def get_mock_response(status_code, content=None, return_type='str', response=None, headers=None):
return RequestMockResponse(status_code, content, return_type, response, headers)

def get_aws_mock_response(obj):
return AWSComposeMockResponse(obj)

class RequestMockResponse:
def __init__(self, status_code, content, return_type='str', response=None):
def __init__(self, status_code, content, return_type='str', response=None, headers=None):
self.code = status_code
self.headers = headers
self.content = content
self.response = response
self.object = response
Expand Down
Loading