diff --git a/cartography/intel/aws/__init__.py b/cartography/intel/aws/__init__.py index 8ae0ccdd8..2a0209602 100644 --- a/cartography/intel/aws/__init__.py +++ b/cartography/intel/aws/__init__.py @@ -1,10 +1,14 @@ import datetime import logging import traceback +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import cpu_count from typing import Any from typing import Dict from typing import Iterable from typing import List +from typing import Optional +from typing import Tuple import boto3 import botocore.exceptions @@ -15,6 +19,7 @@ from .resources import RESOURCE_FUNCTIONS from cartography.config import Config from cartography.intel.aws.util.common import parse_and_validate_aws_requested_syncs +from cartography.neo4j_session_factory import neo4j_session_factory from cartography.stats import get_stats_client from cartography.util import merge_module_sync_metadata from cartography.util import run_analysis_and_ensure_deps @@ -22,7 +27,6 @@ from cartography.util import run_cleanup_job from cartography.util import timeit - stat_handler = get_stats_client(__name__) logger = logging.getLogger(__name__) @@ -138,6 +142,47 @@ def _autodiscover_accounts( logger.warning(f"The current account ({account_id}) doesn't have enough permissions to perform autodiscovery.") +def sync_one_account_runner( + profile_name: str, + account_id: str, + boto3_session: boto3.Session, + neo4j_session: neo4j.Session, + sync_tag: int, + common_job_parameters: Dict[str, Any], + aws_requested_syncs: List[str], + aws_best_effort_mode: bool, +) -> Optional[Tuple[str, str]]: + logger.info(f"Syncing AWS account with ID '{account_id}' using configured profile '{profile_name}'.") + common_job_parameters["AWS_ID"] = account_id + _autodiscover_accounts(neo4j_session, boto3_session, account_id, sync_tag, common_job_parameters) + try: + _sync_one_account( + neo4j_session, + boto3_session, + account_id, + sync_tag, + common_job_parameters, + aws_requested_syncs=aws_requested_syncs, # Could be replaced later with per-account requested syncs + ) + except Exception as e: + exception_traceback = traceback.TracebackException.from_exception(e) + if aws_best_effort_mode: + timestamp = datetime.datetime.now() + traceback_string = ''.join(exception_traceback.format()) + exc_result_string = f'{timestamp} - Exception for account ID: {account_id}\n{traceback_string}' + logger.warning( + f"Caught exception syncing account {account_id}. aws-best-effort-mode is on so we are continuing " + f"on to the next AWS account. All exceptions will be aggregated and re-logged at the end of the " + f"sync.", + exc_info=True, + ) + return account_id, exc_result_string + else: + logger.error(f"AWS sync failed for account {account_id}, see traceback; {exception_traceback}") + raise + return None + + def _sync_multiple_accounts( neo4j_session: neo4j.Session, accounts: Dict[str, str], @@ -153,43 +198,32 @@ def _sync_multiple_accounts( exception_tracebacks = [] num_accounts = len(accounts) - - for profile_name, account_id in accounts.items(): - logger.info("Syncing AWS account with ID '%s' using configured profile '%s'.", account_id, profile_name) - common_job_parameters["AWS_ID"] = account_id - if num_accounts == 1: - # Use the default boto3 session because boto3 gets confused if you give it a profile name with 1 account - boto3_session = boto3.Session() - else: - boto3_session = boto3.Session(profile_name=profile_name) - - _autodiscover_accounts(neo4j_session, boto3_session, account_id, sync_tag, common_job_parameters) - - try: - _sync_one_account( - neo4j_session, - boto3_session, - account_id, - sync_tag, - common_job_parameters, - aws_requested_syncs=aws_requested_syncs, # Could be replaced later with per-account requested syncs - ) - except Exception as e: - if aws_best_effort_mode: - timestamp = datetime.datetime.now() - failed_account_ids.append(account_id) - exception_traceback = traceback.TracebackException.from_exception(e) - traceback_string = ''.join(exception_traceback.format()) - exception_tracebacks.append(f'{timestamp} - Exception for account ID: {account_id}\n{traceback_string}') - logger.warning( - f"Caught exception syncing account {account_id}. aws-best-effort-mode is on so we are continuing " - f"on to the next AWS account. All exceptions will be aggregated and re-logged at the end of the " - f"sync.", - exc_info=True, - ) - continue - else: - raise + num_threads = cpu_count() + + logger.info(f"AWS: Using {num_threads} threads.") + with ThreadPoolExecutor(max_workers=num_threads) as executor: + for profile_name, account_id in accounts.items(): + with neo4j_session_factory.get_new_session() as neo4j_thread_session: + if num_accounts == 1: + # Use the default boto3 session because boto3 gets confused if you give it a profile name w/ 1 acc. + boto3_session = boto3.Session() + else: + boto3_session = boto3.Session(profile_name=profile_name) + + failure: Optional[Tuple[str, str]] = executor.submit( + sync_one_account_runner, + profile_name, + account_id, + boto3_session, + neo4j_thread_session, + sync_tag, + common_job_parameters, + aws_requested_syncs, + aws_best_effort_mode, + ).result() + if failure: + failed_account_ids.append(failure[0]) + exception_tracebacks.append(failure[1]) if failed_account_ids: logger.error(f'AWS sync failed for accounts {failed_account_ids}') diff --git a/cartography/neo4j_session_factory.py b/cartography/neo4j_session_factory.py new file mode 100644 index 000000000..23c5a8bef --- /dev/null +++ b/cartography/neo4j_session_factory.py @@ -0,0 +1,38 @@ +import logging + +import neo4j + +logger = logging.getLogger(__name__) + + +class Neo4jSessionFactory: + _setup = False + _driver = None + _database = None + + def __init__(self): + logger.info("neo4j_session_factory init") + + def initialize(self, neo4j_driver: neo4j.Driver, neo4j_database: str) -> None: + if self._setup: + logger.warning("Reinitializing the Neo4j session factory is not allowed; doing nothing.") + return + + logger.info("Setting up the Neo4j session factory") + + self._setup = True + self._driver = neo4j_driver + self._database = neo4j_database + + def get_new_session(self) -> neo4j.Session: + if not self._setup or not self._driver: + raise RuntimeError( + "Neo4j session factory is not initialized. " + "Make sure that initialize() is called before get_new_session().", + ) + + new_session = self._driver.session(database=self._database) + return new_session + + +neo4j_session_factory = Neo4jSessionFactory() diff --git a/cartography/sync.py b/cartography/sync.py index 4ac02593c..53806a535 100644 --- a/cartography/sync.py +++ b/cartography/sync.py @@ -26,6 +26,7 @@ import cartography.intel.oci import cartography.intel.okta from cartography.config import Config +from cartography.neo4j_session_factory import neo4j_session_factory from cartography.stats import set_stats_client from cartography.util import STATUS_FAILURE from cartography.util import STATUS_SUCCESS @@ -78,6 +79,9 @@ def run(self, neo4j_driver: neo4j.Driver, config: Union[Config, argparse.Namespa :param config: Configuration for the sync run. """ logger.info("Starting sync with update tag '%d'", config.update_tag) + + neo4j_session_factory.initialize(neo4j_driver=neo4j_driver, neo4j_database=config.neo4j_database) + with neo4j_driver.session(database=config.neo4j_database) as neo4j_session: for stage_name, stage_func in self._stages.items(): logger.info("Starting sync stage '%s'", stage_name) diff --git a/tests/unit/cartography/test_neo4j_session_factory.py b/tests/unit/cartography/test_neo4j_session_factory.py new file mode 100644 index 000000000..95b832ba3 --- /dev/null +++ b/tests/unit/cartography/test_neo4j_session_factory.py @@ -0,0 +1,60 @@ +import unittest +from unittest import mock + +import neo4j +import pytest + +from cartography.neo4j_session_factory import Neo4jSessionFactory + + +def test_initialize(): + # Arrange + neo4j_session_factory = Neo4jSessionFactory() + neo4j_driver_mock = mock.Mock(spec=neo4j.Driver) + + # Act + neo4j_session_factory.initialize(neo4j_driver_mock, "test_db") + + # Assert + assert neo4j_session_factory._driver == neo4j_driver_mock + assert neo4j_session_factory._database == "test_db" + + +def test_get_new_session(): + # Arrange + neo4j_session_factory = Neo4jSessionFactory() + neo4j_driver_mock = mock.Mock(spec=neo4j.Driver) + neo4j_session_factory.initialize(neo4j_driver_mock, "test_db") + neo4j_session_mock = mock.Mock() + neo4j_driver_mock.session.return_value = neo4j_session_mock + + # Act + new_session = neo4j_session_factory.get_new_session() + + # Assert + assert new_session == neo4j_session_mock + + +class TestNeo4jSessionFactory(unittest.TestCase): + def setUp(self): + self.driver_mock = mock.Mock(spec=neo4j.Driver) + + def test_reinitialize(self): + # Arrange + neo4j_session_factory = Neo4jSessionFactory() + neo4j_session_factory.initialize(self.driver_mock, "test_db") + + # Act + with self.assertLogs(level="WARNING") as log: + neo4j_session_factory.initialize(self.driver_mock, "test_db") + + # Assert + self.assertIn("Reinitializing the Neo4j session", log.output[0]) + + +def test_neo4j_session_factory_get_new_session_not_initialized(): + neo4j_session_factory = Neo4jSessionFactory() + + with pytest.raises(RuntimeError, match="Neo4j session factory is not initialized"): + new_session = neo4j_session_factory.get_new_session() + assert new_session is None