Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DRAFT] Sync AWS accounts in parallel #1138

Draft
wants to merge 9 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 72 additions & 38 deletions cartography/intel/aws/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,14 @@
import datetime
import logging
import traceback
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import cpu_count
from typing import Any
from typing import Dict
from typing import Iterable
from typing import List
from typing import Optional
from typing import Tuple

import boto3
import botocore.exceptions
Expand All @@ -15,14 +19,14 @@
from .resources import RESOURCE_FUNCTIONS
from cartography.config import Config
from cartography.intel.aws.util.common import parse_and_validate_aws_requested_syncs
from cartography.neo4j_session_factory import neo4j_session_factory
from cartography.stats import get_stats_client
from cartography.util import merge_module_sync_metadata
from cartography.util import run_analysis_and_ensure_deps
from cartography.util import run_analysis_job
from cartography.util import run_cleanup_job
from cartography.util import timeit


stat_handler = get_stats_client(__name__)
logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -138,6 +142,47 @@ def _autodiscover_accounts(
logger.warning(f"The current account ({account_id}) doesn't have enough permissions to perform autodiscovery.")


def sync_one_account_runner(
    profile_name: str,
    account_id: str,
    boto3_session: boto3.Session,
    neo4j_session: neo4j.Session,
    sync_tag: int,
    common_job_parameters: Dict[str, Any],
    aws_requested_syncs: List[str],
    aws_best_effort_mode: bool,
) -> Optional[Tuple[str, str]]:
    """
    Run account autodiscovery and the resource sync for a single AWS account.

    :param profile_name: Name of the configured profile for this account; only used in log output here.
    :param account_id: ID of the AWS account to sync.
    :param boto3_session: boto3 session used for this account.
    :param neo4j_session: Neo4j session used for this account.
    :param sync_tag: Update tag passed through to the sync functions.
    :param common_job_parameters: Job parameters shared by the whole sync run.
    :param aws_requested_syncs: Names of the AWS syncs to run for this account.
    :param aws_best_effort_mode: When True, a failed sync is reported through the return value instead of raised.
    :return: None on success. In best-effort mode, a tuple of (account ID, timestamped traceback text) on failure.
    :raises Exception: Whatever the account sync raised, when best-effort mode is off.
    """
    logger.info(f"Syncing AWS account with ID '{account_id}' using configured profile '{profile_name}'.")

    # The caller's dict is still updated, as before, in case code that runs after the sync expects the key.
    common_job_parameters["AWS_ID"] = account_id
    # This function is meant to be submitted to a thread pool, and every submission receives the same dict.
    # Hand the sync a private shallow copy so a concurrent run cannot swap "AWS_ID" underneath it.
    account_job_parameters = {**common_job_parameters, "AWS_ID": account_id}

    _autodiscover_accounts(neo4j_session, boto3_session, account_id, sync_tag, account_job_parameters)
    try:
        _sync_one_account(
            neo4j_session,
            boto3_session,
            account_id,
            sync_tag,
            account_job_parameters,
            aws_requested_syncs=aws_requested_syncs,  # Could be replaced later with per-account requested syncs
        )
    except Exception as e:
        # str() of a TracebackException is only the exception message, so render the full traceback explicitly.
        traceback_string = ''.join(traceback.TracebackException.from_exception(e).format())
        if not aws_best_effort_mode:
            logger.error(f"AWS sync failed for account {account_id}, see traceback; {traceback_string}")
            raise

        timestamp = datetime.datetime.now()
        exc_result_string = f'{timestamp} - Exception for account ID: {account_id}\n{traceback_string}'
        logger.warning(
            f"Caught exception syncing account {account_id}. aws-best-effort-mode is on so we are continuing "
            f"on to the next AWS account. All exceptions will be aggregated and re-logged at the end of the "
            f"sync.",
            exc_info=True,
        )
        return account_id, exc_result_string
    return None


def _sync_multiple_accounts(
neo4j_session: neo4j.Session,
accounts: Dict[str, str],
Expand All @@ -153,43 +198,32 @@ def _sync_multiple_accounts(
exception_tracebacks = []

num_accounts = len(accounts)

for profile_name, account_id in accounts.items():
logger.info("Syncing AWS account with ID '%s' using configured profile '%s'.", account_id, profile_name)
common_job_parameters["AWS_ID"] = account_id
if num_accounts == 1:
# Use the default boto3 session because boto3 gets confused if you give it a profile name with 1 account
boto3_session = boto3.Session()
else:
boto3_session = boto3.Session(profile_name=profile_name)

_autodiscover_accounts(neo4j_session, boto3_session, account_id, sync_tag, common_job_parameters)

try:
_sync_one_account(
neo4j_session,
boto3_session,
account_id,
sync_tag,
common_job_parameters,
aws_requested_syncs=aws_requested_syncs, # Could be replaced later with per-account requested syncs
)
except Exception as e:
if aws_best_effort_mode:
timestamp = datetime.datetime.now()
failed_account_ids.append(account_id)
exception_traceback = traceback.TracebackException.from_exception(e)
traceback_string = ''.join(exception_traceback.format())
exception_tracebacks.append(f'{timestamp} - Exception for account ID: {account_id}\n{traceback_string}')
logger.warning(
f"Caught exception syncing account {account_id}. aws-best-effort-mode is on so we are continuing "
f"on to the next AWS account. All exceptions will be aggregated and re-logged at the end of the "
f"sync.",
exc_info=True,
)
continue
else:
raise
num_threads = cpu_count()

logger.info(f"AWS: Using {num_threads} threads.")
with ThreadPoolExecutor(max_workers=num_threads) as executor:
for profile_name, account_id in accounts.items():
with neo4j_session_factory.get_new_session() as neo4j_thread_session:
if num_accounts == 1:
# Use the default boto3 session because boto3 gets confused if you give it a profile name w/ 1 acc.
boto3_session = boto3.Session()
else:
boto3_session = boto3.Session(profile_name=profile_name)

failure: Optional[Tuple[str, str]] = executor.submit(
sync_one_account_runner,
profile_name,
account_id,
boto3_session,
neo4j_thread_session,
sync_tag,
common_job_parameters,
aws_requested_syncs,
aws_best_effort_mode,
).result()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I believe this will run the syncs one after another rather than concurrently, because calling result() on the Future returned by executor.submit blocks until that future completes.
To get concurrent execution, I believe you would want to submit the sync functions in a loop as is done here, collect the returned Future objects in a list, and then call concurrent.futures.wait to wait for all of the syncs to complete (or iterate over concurrent.futures.as_completed if you want to report sync statuses as they complete).

if failure:
failed_account_ids.append(failure[0])
exception_tracebacks.append(failure[1])

if failed_account_ids:
logger.error(f'AWS sync failed for accounts {failed_account_ids}')
Expand Down
38 changes: 38 additions & 0 deletions cartography/neo4j_session_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import logging

import neo4j

logger = logging.getLogger(__name__)


class Neo4jSessionFactory:
    """
    Hands out new Neo4j sessions from a driver that is registered once via initialize().
    """
    # Class-level defaults; initialize() shadows them with per-instance values.
    _setup = False
    _driver = None
    _database = None

    def __init__(self):
        logger.info("neo4j_session_factory init")

    def initialize(self, neo4j_driver: neo4j.Driver, neo4j_database: str) -> None:
        """
        Register the driver and database name used to create sessions.

        Only the first call has any effect; later calls log a warning and change nothing.

        :param neo4j_driver: Driver that new sessions are created from.
        :param neo4j_database: Database name passed to the driver when a session is opened.
        """
        if not self._setup:
            logger.info("Setting up the Neo4j session factory")
            self._setup = True
            self._driver, self._database = neo4j_driver, neo4j_database
            return

        logger.warning("Reinitializing the Neo4j session factory is not allowed; doing nothing.")

    def get_new_session(self) -> neo4j.Session:
        """
        Open a new session on the registered driver for the registered database.

        :return: The session object returned by the driver.
        :raises RuntimeError: If initialize() has not been called or no driver is set.
        """
        if self._setup and self._driver:
            return self._driver.session(database=self._database)

        not_ready_message = (
            "Neo4j session factory is not initialized. "
            "Make sure that initialize() is called before get_new_session()."
        )
        raise RuntimeError(not_ready_message)


# Module-level factory instance that other modules import. initialize() must be called on it before
# get_new_session(); otherwise get_new_session() raises RuntimeError.
neo4j_session_factory = Neo4jSessionFactory()
4 changes: 4 additions & 0 deletions cartography/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import cartography.intel.oci
import cartography.intel.okta
from cartography.config import Config
from cartography.neo4j_session_factory import neo4j_session_factory
from cartography.stats import set_stats_client
from cartography.util import STATUS_FAILURE
from cartography.util import STATUS_SUCCESS
Expand Down Expand Up @@ -78,6 +79,9 @@ def run(self, neo4j_driver: neo4j.Driver, config: Union[Config, argparse.Namespa
:param config: Configuration for the sync run.
"""
logger.info("Starting sync with update tag '%d'", config.update_tag)

neo4j_session_factory.initialize(neo4j_driver=neo4j_driver, neo4j_database=config.neo4j_database)

with neo4j_driver.session(database=config.neo4j_database) as neo4j_session:
for stage_name, stage_func in self._stages.items():
logger.info("Starting sync stage '%s'", stage_name)
Expand Down
60 changes: 60 additions & 0 deletions tests/unit/cartography/test_neo4j_session_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
import unittest
from unittest import mock

import neo4j
import pytest

from cartography.neo4j_session_factory import Neo4jSessionFactory


def test_initialize():
    """initialize() stores the driver and database name it was given on the factory."""
    # Arrange
    driver = mock.Mock(spec=neo4j.Driver)
    database_name = "test_db"
    factory = Neo4jSessionFactory()

    # Act
    factory.initialize(driver, database_name)

    # Assert
    stored_driver = factory._driver
    stored_database = factory._database
    assert stored_driver == driver
    assert stored_database == "test_db"


def test_get_new_session():
    """get_new_session() returns the driver's session and requests it for the configured database."""
    # Arrange
    neo4j_session_factory = Neo4jSessionFactory()
    neo4j_driver_mock = mock.Mock(spec=neo4j.Driver)
    neo4j_session_factory.initialize(neo4j_driver_mock, "test_db")
    neo4j_session_mock = mock.Mock()
    neo4j_driver_mock.session.return_value = neo4j_session_mock

    # Act
    new_session = neo4j_session_factory.get_new_session()

    # Assert
    assert new_session == neo4j_session_mock
    # Also pin the database argument, so a session opened against the wrong database fails the test.
    neo4j_driver_mock.session.assert_called_once_with(database="test_db")


class TestNeo4jSessionFactory(unittest.TestCase):
    """Checks that a second initialize() call warns and leaves the factory untouched."""

    def setUp(self):
        self.driver_mock = mock.Mock(spec=neo4j.Driver)

    def test_reinitialize(self):
        # Arrange
        neo4j_session_factory = Neo4jSessionFactory()
        neo4j_session_factory.initialize(self.driver_mock, "test_db")
        # Use different values for the second call; identical ones could not show that it was ignored.
        other_driver_mock = mock.Mock(spec=neo4j.Driver)

        # Act
        with self.assertLogs(level="WARNING") as log:
            neo4j_session_factory.initialize(other_driver_mock, "other_db")

        # Assert
        self.assertIn("Reinitializing the Neo4j session", log.output[0])
        self.assertIs(neo4j_session_factory._driver, self.driver_mock)
        self.assertEqual(neo4j_session_factory._database, "test_db")


def test_neo4j_session_factory_get_new_session_not_initialized():
    """get_new_session() raises RuntimeError on a factory that was never initialized."""
    neo4j_session_factory = Neo4jSessionFactory()

    # Only the raising call belongs in this block: the exception leaves it immediately, so any
    # statement placed after the call would never run.
    with pytest.raises(RuntimeError, match="Neo4j session factory is not initialized"):
        neo4j_session_factory.get_new_session()