From 11120938e6d4ddecacce83ef1ee349982d6864b4 Mon Sep 17 00:00:00 2001 From: Alex Chantavy Date: Tue, 14 Mar 2023 16:47:03 -0700 Subject: [PATCH 1/9] Initial commit for parallel account sync --- cartography/intel/aws/__init__.py | 95 ++++++++++++++++++++----------- 1 file changed, 63 insertions(+), 32 deletions(-) diff --git a/cartography/intel/aws/__init__.py b/cartography/intel/aws/__init__.py index 8ae0ccdd8..5a66bbbe7 100644 --- a/cartography/intel/aws/__init__.py +++ b/cartography/intel/aws/__init__.py @@ -1,10 +1,14 @@ import datetime import logging import traceback +from concurrent.futures import ThreadPoolExecutor +from multiprocessing import cpu_count from typing import Any from typing import Dict from typing import Iterable from typing import List +from typing import Optional +from typing import Tuple import boto3 import botocore.exceptions @@ -138,6 +142,47 @@ def _autodiscover_accounts( logger.warning(f"The current account ({account_id}) doesn't have enough permissions to perform autodiscovery.") +def sync_one_account_runner( + profile_name: str, + account_id: str, + boto3_session: boto3.Session, + neo4j_session: neo4j.Session, + sync_tag: int, + common_job_parameters: Dict[str, Any], + aws_requested_syncs: List[str], + aws_best_effort_mode: bool, +) -> Optional[Tuple[str, str]]: + logger.info(f"Syncing AWS account with ID '{account_id}' using configured profile '{profile_name}'.") + common_job_parameters["AWS_ID"] = account_id + _autodiscover_accounts(neo4j_session, boto3_session, account_id, sync_tag, common_job_parameters) + try: + _sync_one_account( + neo4j_session, + boto3_session, + account_id, + sync_tag, + common_job_parameters, + aws_requested_syncs=aws_requested_syncs, # Could be replaced later with per-account requested syncs + ) + except Exception as e: + exception_traceback = traceback.TracebackException.from_exception(e) + if aws_best_effort_mode: + timestamp = datetime.datetime.now() + traceback_string = 
''.join(exception_traceback.format()) + exc_result_string = f'{timestamp} - Exception for account ID: {account_id}\n{traceback_string}' + logger.warning( + f"Caught exception syncing account {account_id}. aws-best-effort-mode is on so we are continuing " + f"on to the next AWS account. All exceptions will be aggregated and re-logged at the end of the " + f"sync.", + exc_info=True, + ) + return account_id, exc_result_string + else: + logger.error(f"AWS sync failed for account {account_id}, see traceback; {exception_traceback}") + raise + return None + + def _sync_multiple_accounts( neo4j_session: neo4j.Session, accounts: Dict[str, str], @@ -154,42 +199,28 @@ def _sync_multiple_accounts( num_accounts = len(accounts) - for profile_name, account_id in accounts.items(): - logger.info("Syncing AWS account with ID '%s' using configured profile '%s'.", account_id, profile_name) - common_job_parameters["AWS_ID"] = account_id - if num_accounts == 1: - # Use the default boto3 session because boto3 gets confused if you give it a profile name with 1 account - boto3_session = boto3.Session() - else: - boto3_session = boto3.Session(profile_name=profile_name) - - _autodiscover_accounts(neo4j_session, boto3_session, account_id, sync_tag, common_job_parameters) + with ThreadPoolExecutor(max_workers=cpu_count()) as executor: + for profile_name, account_id in accounts.items(): + if num_accounts == 1: + # Use the default boto3 session because boto3 gets confused if you give it a profile name with 1 account + boto3_session = boto3.Session() + else: + boto3_session = boto3.Session(profile_name=profile_name) - try: - _sync_one_account( - neo4j_session, - boto3_session, + failure: Optional[Tuple[str, str]] = executor.submit( + sync_one_account_runner, + profile_name, account_id, + boto3_session, + neo4j_session, sync_tag, common_job_parameters, - aws_requested_syncs=aws_requested_syncs, # Could be replaced later with per-account requested syncs - ) - except Exception as e: - if 
aws_best_effort_mode: - timestamp = datetime.datetime.now() - failed_account_ids.append(account_id) - exception_traceback = traceback.TracebackException.from_exception(e) - traceback_string = ''.join(exception_traceback.format()) - exception_tracebacks.append(f'{timestamp} - Exception for account ID: {account_id}\n{traceback_string}') - logger.warning( - f"Caught exception syncing account {account_id}. aws-best-effort-mode is on so we are continuing " - f"on to the next AWS account. All exceptions will be aggregated and re-logged at the end of the " - f"sync.", - exc_info=True, - ) - continue - else: - raise + aws_requested_syncs, + aws_best_effort_mode, + ).result() + if failure: + failed_account_ids.append(failure[0]) + exception_tracebacks.append(failure[1]) if failed_account_ids: logger.error(f'AWS sync failed for accounts {failed_account_ids}') From 90c9d6b2abe8b13dcd18b4b9fd6fb90936792945 Mon Sep 17 00:00:00 2001 From: Alex Chantavy Date: Tue, 14 Mar 2023 16:56:38 -0700 Subject: [PATCH 2/9] Log the number of threads --- cartography/intel/aws/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/cartography/intel/aws/__init__.py b/cartography/intel/aws/__init__.py index 5a66bbbe7..fe923cd79 100644 --- a/cartography/intel/aws/__init__.py +++ b/cartography/intel/aws/__init__.py @@ -198,8 +198,10 @@ def _sync_multiple_accounts( exception_tracebacks = [] num_accounts = len(accounts) + num_threads = cpu_count() - with ThreadPoolExecutor(max_workers=cpu_count()) as executor: + logger.info(f"AWS: Using {num_threads} threads.") + with ThreadPoolExecutor(max_workers=num_threads) as executor: for profile_name, account_id in accounts.items(): if num_accounts == 1: # Use the default boto3 session because boto3 gets confused if you give it a profile name with 1 account From 691a42b9a85b70712d258639b47e6dadffcb6e47 Mon Sep 17 00:00:00 2001 From: Harsh Agarwal Date: Sun, 18 Dec 2022 23:21:36 +0530 Subject: [PATCH 3/9] added neo4j session factory 
along with threading --- cartography/intel/aws/ec2/vpc.py | 47 +++++++++++++++++++++++++----- cartography/neo4jSessionFactory.py | 34 +++++++++++++++++++++ cartography/sync.py | 5 ++++ 3 files changed, 78 insertions(+), 8 deletions(-) create mode 100644 cartography/neo4jSessionFactory.py diff --git a/cartography/intel/aws/ec2/vpc.py b/cartography/intel/aws/ec2/vpc.py index f62194c48..f90bbd4c0 100644 --- a/cartography/intel/aws/ec2/vpc.py +++ b/cartography/intel/aws/ec2/vpc.py @@ -11,6 +11,9 @@ from cartography.util import run_cleanup_job from cartography.util import timeit +import threading +from cartography.neo4jSessionFactory import factory as neo4jFactory + logger = logging.getLogger(__name__) @@ -58,8 +61,8 @@ def _get_cidr_association_statement(block_type: str) -> str: @timeit def load_cidr_association_set( - neo4j_session: neo4j.Session, vpc_id: str, vpc_data: Dict, block_type: str, - update_tag: int, + neo4j_session: neo4j.Session, vpc_id: str, vpc_data: Dict, block_type: str, + update_tag: int, ) -> None: ingest_statement = _get_cidr_association_statement(block_type) @@ -78,8 +81,8 @@ def load_cidr_association_set( @timeit def load_ec2_vpcs( - neo4j_session: neo4j.Session, data: List[Dict], region: str, current_aws_account_id: str, - update_tag: int, + neo4j_session: neo4j.Session, data: List[Dict], region: str, current_aws_account_id: str, + update_tag: int, ) -> None: # https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-vpcs.html # { @@ -164,13 +167,41 @@ def cleanup_ec2_vpcs(neo4j_session: neo4j.Session, common_job_parameters: Dict) run_cleanup_job('aws_import_vpc_cleanup.json', neo4j_session, common_job_parameters) +def sync_vpc_per_region( + boto3_session: boto3.session.Session, + region: str, + current_aws_account_id: str, + update_tag: int, ) -> None: + logger.info("Syncing EC2 VPC for region '%s' in account '%s'.", region, current_aws_account_id) + with neo4jFactory.get_new_session() as neo4j_thread_session: + data = 
get_ec2_vpcs(boto3_session, region) + load_ec2_vpcs(neo4j_thread_session, data, region, current_aws_account_id, update_tag) + + @timeit def sync_vpc( - neo4j_session: neo4j.Session, boto3_session: boto3.session.Session, regions: List[str], current_aws_account_id: str, - update_tag: int, common_job_parameters: Dict, + neo4j_session: neo4j.Session, boto3_session: boto3.session.Session, regions: List[str], + current_aws_account_id: str, + update_tag: int, common_job_parameters: Dict, ) -> None: + # for region in regions: + # logger.info("Syncing EC2 VPC for region '%s' in account '%s'.", region, current_aws_account_id) + # data = get_ec2_vpcs(boto3_session, region) + # load_ec2_vpcs(neo4j_session, data, region, current_aws_account_id, update_tag) + + ts = [] + for region in regions: logger.info("Syncing EC2 VPC for region '%s' in account '%s'.", region, current_aws_account_id) - data = get_ec2_vpcs(boto3_session, region) - load_ec2_vpcs(neo4j_session, data, region, current_aws_account_id, update_tag) + t = threading.Thread(target=sync_vpc_per_region, + args=( + boto3_session, region, + current_aws_account_id, + update_tag,)) + t.start() + ts.append(t) + + for t in ts: + t.join() + cleanup_ec2_vpcs(neo4j_session, common_job_parameters) diff --git a/cartography/neo4jSessionFactory.py b/cartography/neo4jSessionFactory.py new file mode 100644 index 000000000..c99ba1134 --- /dev/null +++ b/cartography/neo4jSessionFactory.py @@ -0,0 +1,34 @@ +import logging + +logger = logging.getLogger(__name__) + + +class Neo4JSessionFactory: + _setup = False + _driver = None + _database = None + + def __init__(self): + logger.info("Neo4JFactory Init") + + def initialize(self, neo4j_driver, neo4j_database): + if self._setup: + logger.warning("Reinitializing the Neo4JSessionFactory. 
It is not allowed.") + return + + logger.info("Setting up the Neo4JSessionFactory") + + self._setup = True + self._driver = neo4j_driver + self._database = neo4j_database + + def get_new_session(self): + if not self._setup: + logger.warning("Neo4JSessionFactory is not setup") + return ClientError + + session = self._driver.session(database=self._database) + return session + + +factory = Neo4JSessionFactory() diff --git a/cartography/sync.py b/cartography/sync.py index 4ac02593c..492c638db 100644 --- a/cartography/sync.py +++ b/cartography/sync.py @@ -30,6 +30,8 @@ from cartography.util import STATUS_FAILURE from cartography.util import STATUS_SUCCESS +from . import neo4jSessionFactory + logger = logging.getLogger(__name__) @@ -78,6 +80,9 @@ def run(self, neo4j_driver: neo4j.Driver, config: Union[Config, argparse.Namespa :param config: Configuration for the sync run. """ logger.info("Starting sync with update tag '%d'", config.update_tag) + + neo4jSessionFactory.factory.initialize(neo4j_driver=neo4j_driver, neo4j_database=config.neo4j_database) + with neo4j_driver.session(database=config.neo4j_database) as neo4j_session: for stage_name, stage_func in self._stages.items(): logger.info("Starting sync stage '%s'", stage_name) From d89d0f4ba62b9df70b7c6094a858ccc9357d6185 Mon Sep 17 00:00:00 2001 From: Harsh Agarwal Date: Wed, 21 Dec 2022 10:43:36 +0530 Subject: [PATCH 4/9] performed linting --- cartography/intel/aws/ec2/vpc.py | 24 ++++++++++-------------- cartography/neo4jSessionFactory.py | 17 ++++++++++------- cartography/sync.py | 3 +-- 3 files changed, 21 insertions(+), 23 deletions(-) diff --git a/cartography/intel/aws/ec2/vpc.py b/cartography/intel/aws/ec2/vpc.py index f90bbd4c0..c549b16b4 100644 --- a/cartography/intel/aws/ec2/vpc.py +++ b/cartography/intel/aws/ec2/vpc.py @@ -1,4 +1,5 @@ import logging +import threading from string import Template from typing import Dict from typing import List @@ -7,13 +8,11 @@ import neo4j from .util import 
get_botocore_config +from cartography.neo4jSessionFactory import factory as neo4jFactory from cartography.util import aws_handle_regions from cartography.util import run_cleanup_job from cartography.util import timeit -import threading -from cartography.neo4jSessionFactory import factory as neo4jFactory - logger = logging.getLogger(__name__) @@ -184,20 +183,17 @@ def sync_vpc( current_aws_account_id: str, update_tag: int, common_job_parameters: Dict, ) -> None: - # for region in regions: - # logger.info("Syncing EC2 VPC for region '%s' in account '%s'.", region, current_aws_account_id) - # data = get_ec2_vpcs(boto3_session, region) - # load_ec2_vpcs(neo4j_session, data, region, current_aws_account_id, update_tag) - ts = [] for region in regions: - logger.info("Syncing EC2 VPC for region '%s' in account '%s'.", region, current_aws_account_id) - t = threading.Thread(target=sync_vpc_per_region, - args=( - boto3_session, region, - current_aws_account_id, - update_tag,)) + t = threading.Thread( + target=sync_vpc_per_region, + args=( + boto3_session, region, + current_aws_account_id, + update_tag, + ), + ) t.start() ts.append(t) diff --git a/cartography/neo4jSessionFactory.py b/cartography/neo4jSessionFactory.py index c99ba1134..a5ca85cfc 100644 --- a/cartography/neo4jSessionFactory.py +++ b/cartography/neo4jSessionFactory.py @@ -1,4 +1,7 @@ import logging +from typing import Any + +import neo4j logger = logging.getLogger(__name__) @@ -11,7 +14,7 @@ class Neo4JSessionFactory: def __init__(self): logger.info("Neo4JFactory Init") - def initialize(self, neo4j_driver, neo4j_database): + def initialize(self, neo4j_driver: neo4j.Driver, neo4j_database: str) -> None: if self._setup: logger.warning("Reinitializing the Neo4JSessionFactory. 
It is not allowed.") return @@ -22,13 +25,13 @@ def initialize(self, neo4j_driver, neo4j_database): self._driver = neo4j_driver self._database = neo4j_database - def get_new_session(self): - if not self._setup: - logger.warning("Neo4JSessionFactory is not setup") - return ClientError + def get_new_session(self) -> Any: + if not self._setup or not self._driver: + logger.warning("Neo4J Factory is not initialized.") + return - session = self._driver.session(database=self._database) - return session + new_session = self._driver.session(database=self._database) + return new_session factory = Neo4JSessionFactory() diff --git a/cartography/sync.py b/cartography/sync.py index 492c638db..9de9e69b5 100644 --- a/cartography/sync.py +++ b/cartography/sync.py @@ -25,13 +25,12 @@ import cartography.intel.kubernetes import cartography.intel.oci import cartography.intel.okta +from . import neo4jSessionFactory from cartography.config import Config from cartography.stats import set_stats_client from cartography.util import STATUS_FAILURE from cartography.util import STATUS_SUCCESS -from . 
import neo4jSessionFactory - logger = logging.getLogger(__name__) From c2d6ed31e657a582455ee89d9ce52deb8547f9a2 Mon Sep 17 00:00:00 2001 From: Alex Chantavy Date: Tue, 14 Mar 2023 20:07:14 -0700 Subject: [PATCH 5/9] Use session factory for parallel by account, revert vpc region changes --- cartography/intel/aws/__init__.py | 43 +++++++++--------- cartography/intel/aws/ec2/vpc.py | 45 ++++--------------- ...ionFactory.py => neo4j_session_factory.py} | 12 ++--- cartography/sync.py | 4 +- 4 files changed, 39 insertions(+), 65 deletions(-) rename cartography/{neo4jSessionFactory.py => neo4j_session_factory.py} (63%) diff --git a/cartography/intel/aws/__init__.py b/cartography/intel/aws/__init__.py index fe923cd79..2a0209602 100644 --- a/cartography/intel/aws/__init__.py +++ b/cartography/intel/aws/__init__.py @@ -19,6 +19,7 @@ from .resources import RESOURCE_FUNCTIONS from cartography.config import Config from cartography.intel.aws.util.common import parse_and_validate_aws_requested_syncs +from cartography.neo4j_session_factory import neo4j_session_factory from cartography.stats import get_stats_client from cartography.util import merge_module_sync_metadata from cartography.util import run_analysis_and_ensure_deps @@ -26,7 +27,6 @@ from cartography.util import run_cleanup_job from cartography.util import timeit - stat_handler = get_stats_client(__name__) logger = logging.getLogger(__name__) @@ -203,26 +203,27 @@ def _sync_multiple_accounts( logger.info(f"AWS: Using {num_threads} threads.") with ThreadPoolExecutor(max_workers=num_threads) as executor: for profile_name, account_id in accounts.items(): - if num_accounts == 1: - # Use the default boto3 session because boto3 gets confused if you give it a profile name with 1 account - boto3_session = boto3.Session() - else: - boto3_session = boto3.Session(profile_name=profile_name) - - failure: Optional[Tuple[str, str]] = executor.submit( - sync_one_account_runner, - profile_name, - account_id, - boto3_session, - 
neo4j_session, - sync_tag, - common_job_parameters, - aws_requested_syncs, - aws_best_effort_mode, - ).result() - if failure: - failed_account_ids.append(failure[0]) - exception_tracebacks.append(failure[1]) + with neo4j_session_factory.get_new_session() as neo4j_thread_session: + if num_accounts == 1: + # Use the default boto3 session because boto3 gets confused if you give it a profile name w/ 1 acc. + boto3_session = boto3.Session() + else: + boto3_session = boto3.Session(profile_name=profile_name) + + failure: Optional[Tuple[str, str]] = executor.submit( + sync_one_account_runner, + profile_name, + account_id, + boto3_session, + neo4j_thread_session, + sync_tag, + common_job_parameters, + aws_requested_syncs, + aws_best_effort_mode, + ).result() + if failure: + failed_account_ids.append(failure[0]) + exception_tracebacks.append(failure[1]) if failed_account_ids: logger.error(f'AWS sync failed for accounts {failed_account_ids}') diff --git a/cartography/intel/aws/ec2/vpc.py b/cartography/intel/aws/ec2/vpc.py index c549b16b4..f62194c48 100644 --- a/cartography/intel/aws/ec2/vpc.py +++ b/cartography/intel/aws/ec2/vpc.py @@ -1,5 +1,4 @@ import logging -import threading from string import Template from typing import Dict from typing import List @@ -8,7 +7,6 @@ import neo4j from .util import get_botocore_config -from cartography.neo4jSessionFactory import factory as neo4jFactory from cartography.util import aws_handle_regions from cartography.util import run_cleanup_job from cartography.util import timeit @@ -60,8 +58,8 @@ def _get_cidr_association_statement(block_type: str) -> str: @timeit def load_cidr_association_set( - neo4j_session: neo4j.Session, vpc_id: str, vpc_data: Dict, block_type: str, - update_tag: int, + neo4j_session: neo4j.Session, vpc_id: str, vpc_data: Dict, block_type: str, + update_tag: int, ) -> None: ingest_statement = _get_cidr_association_statement(block_type) @@ -80,8 +78,8 @@ def load_cidr_association_set( @timeit def load_ec2_vpcs( - 
neo4j_session: neo4j.Session, data: List[Dict], region: str, current_aws_account_id: str, - update_tag: int, + neo4j_session: neo4j.Session, data: List[Dict], region: str, current_aws_account_id: str, + update_tag: int, ) -> None: # https://docs.aws.amazon.com/cli/latest/reference/ec2/describe-vpcs.html # { @@ -166,38 +164,13 @@ def cleanup_ec2_vpcs(neo4j_session: neo4j.Session, common_job_parameters: Dict) run_cleanup_job('aws_import_vpc_cleanup.json', neo4j_session, common_job_parameters) -def sync_vpc_per_region( - boto3_session: boto3.session.Session, - region: str, - current_aws_account_id: str, - update_tag: int, ) -> None: - logger.info("Syncing EC2 VPC for region '%s' in account '%s'.", region, current_aws_account_id) - with neo4jFactory.get_new_session() as neo4j_thread_session: - data = get_ec2_vpcs(boto3_session, region) - load_ec2_vpcs(neo4j_thread_session, data, region, current_aws_account_id, update_tag) - - @timeit def sync_vpc( - neo4j_session: neo4j.Session, boto3_session: boto3.session.Session, regions: List[str], - current_aws_account_id: str, - update_tag: int, common_job_parameters: Dict, + neo4j_session: neo4j.Session, boto3_session: boto3.session.Session, regions: List[str], current_aws_account_id: str, + update_tag: int, common_job_parameters: Dict, ) -> None: - ts = [] - for region in regions: - t = threading.Thread( - target=sync_vpc_per_region, - args=( - boto3_session, region, - current_aws_account_id, - update_tag, - ), - ) - t.start() - ts.append(t) - - for t in ts: - t.join() - + logger.info("Syncing EC2 VPC for region '%s' in account '%s'.", region, current_aws_account_id) + data = get_ec2_vpcs(boto3_session, region) + load_ec2_vpcs(neo4j_session, data, region, current_aws_account_id, update_tag) cleanup_ec2_vpcs(neo4j_session, common_job_parameters) diff --git a/cartography/neo4jSessionFactory.py b/cartography/neo4j_session_factory.py similarity index 63% rename from cartography/neo4jSessionFactory.py rename to 
cartography/neo4j_session_factory.py index a5ca85cfc..39b16bc89 100644 --- a/cartography/neo4jSessionFactory.py +++ b/cartography/neo4j_session_factory.py @@ -6,20 +6,20 @@ logger = logging.getLogger(__name__) -class Neo4JSessionFactory: +class Neo4jSessionFactory: _setup = False _driver = None _database = None def __init__(self): - logger.info("Neo4JFactory Init") + logger.info("Neo4j neo4j_session_factory init") def initialize(self, neo4j_driver: neo4j.Driver, neo4j_database: str) -> None: if self._setup: - logger.warning("Reinitializing the Neo4JSessionFactory. It is not allowed.") + logger.warning("Reinitializing the Neo4j session neo4j_session_factory. It is not allowed.") return - logger.info("Setting up the Neo4JSessionFactory") + logger.info("Setting up the Neo4j Session Factory") self._setup = True self._driver = neo4j_driver @@ -27,11 +27,11 @@ def initialize(self, neo4j_driver: neo4j.Driver, neo4j_database: str) -> None: def get_new_session(self) -> Any: if not self._setup or not self._driver: - logger.warning("Neo4J Factory is not initialized.") + logger.warning("Neo4j Factory is not initialized.") return new_session = self._driver.session(database=self._database) return new_session -factory = Neo4JSessionFactory() +neo4j_session_factory = Neo4jSessionFactory() diff --git a/cartography/sync.py b/cartography/sync.py index 9de9e69b5..e8ba95ac3 100644 --- a/cartography/sync.py +++ b/cartography/sync.py @@ -9,6 +9,7 @@ import neo4j.exceptions from neo4j import GraphDatabase +from neo4j_session_factory import neo4j_session_factory from statsd import StatsClient import cartography.intel.analysis @@ -25,7 +26,6 @@ import cartography.intel.kubernetes import cartography.intel.oci import cartography.intel.okta -from . 
import neo4jSessionFactory from cartography.config import Config from cartography.stats import set_stats_client from cartography.util import STATUS_FAILURE @@ -80,7 +80,7 @@ def run(self, neo4j_driver: neo4j.Driver, config: Union[Config, argparse.Namespa """ logger.info("Starting sync with update tag '%d'", config.update_tag) - neo4jSessionFactory.factory.initialize(neo4j_driver=neo4j_driver, neo4j_database=config.neo4j_database) + neo4j_session_factory.initialize(neo4j_driver=neo4j_driver, neo4j_database=config.neo4j_database) with neo4j_driver.session(database=config.neo4j_database) as neo4j_session: for stage_name, stage_func in self._stages.items(): From ed4586b1af1ca7280ecdc5262f63fb15bd4815b0 Mon Sep 17 00:00:00 2001 From: Alex Chantavy Date: Tue, 14 Mar 2023 21:52:36 -0700 Subject: [PATCH 6/9] Unit test factory --- cartography/neo4j_session_factory.py | 15 ++--- .../cartography/test_neo4j_session_factory.py | 60 +++++++++++++++++++ 2 files changed, 68 insertions(+), 7 deletions(-) create mode 100644 tests/unit/cartography/test_neo4j_session_factory.py diff --git a/cartography/neo4j_session_factory.py b/cartography/neo4j_session_factory.py index 39b16bc89..e3273b1d4 100644 --- a/cartography/neo4j_session_factory.py +++ b/cartography/neo4j_session_factory.py @@ -1,5 +1,4 @@ import logging -from typing import Any import neo4j @@ -12,23 +11,25 @@ class Neo4jSessionFactory: _database = None def __init__(self): - logger.info("Neo4j neo4j_session_factory init") + logger.info("neo4j_session_factory init") def initialize(self, neo4j_driver: neo4j.Driver, neo4j_database: str) -> None: if self._setup: - logger.warning("Reinitializing the Neo4j session neo4j_session_factory. 
It is not allowed.") + logger.warning("Reinitializing the Neo4j session factory is not allowed; doing nothing.") return - logger.info("Setting up the Neo4j Session Factory") + logger.info("Setting up the Neo4j session factory") self._setup = True self._driver = neo4j_driver self._database = neo4j_database - def get_new_session(self) -> Any: + def get_new_session(self) -> neo4j.Session: if not self._setup or not self._driver: - logger.warning("Neo4j Factory is not initialized.") - return + raise RuntimeError( + "Neo4j session factory is not initialized. " + "Make sure that initialize() is called before get_new_session()." + ) new_session = self._driver.session(database=self._database) return new_session diff --git a/tests/unit/cartography/test_neo4j_session_factory.py b/tests/unit/cartography/test_neo4j_session_factory.py new file mode 100644 index 000000000..95b832ba3 --- /dev/null +++ b/tests/unit/cartography/test_neo4j_session_factory.py @@ -0,0 +1,60 @@ +import unittest +from unittest import mock + +import neo4j +import pytest + +from cartography.neo4j_session_factory import Neo4jSessionFactory + + +def test_initialize(): + # Arrange + neo4j_session_factory = Neo4jSessionFactory() + neo4j_driver_mock = mock.Mock(spec=neo4j.Driver) + + # Act + neo4j_session_factory.initialize(neo4j_driver_mock, "test_db") + + # Assert + assert neo4j_session_factory._driver == neo4j_driver_mock + assert neo4j_session_factory._database == "test_db" + + +def test_get_new_session(): + # Arrange + neo4j_session_factory = Neo4jSessionFactory() + neo4j_driver_mock = mock.Mock(spec=neo4j.Driver) + neo4j_session_factory.initialize(neo4j_driver_mock, "test_db") + neo4j_session_mock = mock.Mock() + neo4j_driver_mock.session.return_value = neo4j_session_mock + + # Act + new_session = neo4j_session_factory.get_new_session() + + # Assert + assert new_session == neo4j_session_mock + + +class TestNeo4jSessionFactory(unittest.TestCase): + def setUp(self): + self.driver_mock = 
mock.Mock(spec=neo4j.Driver) + + def test_reinitialize(self): + # Arrange + neo4j_session_factory = Neo4jSessionFactory() + neo4j_session_factory.initialize(self.driver_mock, "test_db") + + # Act + with self.assertLogs(level="WARNING") as log: + neo4j_session_factory.initialize(self.driver_mock, "test_db") + + # Assert + self.assertIn("Reinitializing the Neo4j session", log.output[0]) + + +def test_neo4j_session_factory_get_new_session_not_initialized(): + neo4j_session_factory = Neo4jSessionFactory() + + with pytest.raises(RuntimeError, match="Neo4j session factory is not initialized"): + new_session = neo4j_session_factory.get_new_session() + assert new_session is None From fcd072fa83255fe1107ad745c126c6cea8add7e5 Mon Sep 17 00:00:00 2001 From: Alex Chantavy Date: Tue, 14 Mar 2023 21:56:09 -0700 Subject: [PATCH 7/9] Linter --- cartography/neo4j_session_factory.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cartography/neo4j_session_factory.py b/cartography/neo4j_session_factory.py index e3273b1d4..23c5a8bef 100644 --- a/cartography/neo4j_session_factory.py +++ b/cartography/neo4j_session_factory.py @@ -28,7 +28,7 @@ def get_new_session(self) -> neo4j.Session: if not self._setup or not self._driver: raise RuntimeError( "Neo4j session factory is not initialized. " - "Make sure that initialize() is called before get_new_session()." 
+ "Make sure that initialize() is called before get_new_session().", ) new_session = self._driver.session(database=self._database) From cc65b3f66ec674915d50616e31bb3f468f30f21b Mon Sep 17 00:00:00 2001 From: Alex Chantavy Date: Tue, 14 Mar 2023 21:59:52 -0700 Subject: [PATCH 8/9] fix import --- cartography/sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cartography/sync.py b/cartography/sync.py index e8ba95ac3..2b25fe4b9 100644 --- a/cartography/sync.py +++ b/cartography/sync.py @@ -9,7 +9,7 @@ import neo4j.exceptions from neo4j import GraphDatabase -from neo4j_session_factory import neo4j_session_factory +from cartography.neo4j_session_factory import neo4j_session_factory from statsd import StatsClient import cartography.intel.analysis From d061afba177537a93d1d7a6692dd7f62f4e23cdd Mon Sep 17 00:00:00 2001 From: Alex Chantavy Date: Tue, 14 Mar 2023 22:01:06 -0700 Subject: [PATCH 9/9] linter again --- cartography/sync.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cartography/sync.py b/cartography/sync.py index 2b25fe4b9..53806a535 100644 --- a/cartography/sync.py +++ b/cartography/sync.py @@ -9,7 +9,6 @@ import neo4j.exceptions from neo4j import GraphDatabase -from cartography.neo4j_session_factory import neo4j_session_factory from statsd import StatsClient import cartography.intel.analysis @@ -27,6 +26,7 @@ import cartography.intel.oci import cartography.intel.okta from cartography.config import Config +from cartography.neo4j_session_factory import neo4j_session_factory from cartography.stats import set_stats_client from cartography.util import STATUS_FAILURE from cartography.util import STATUS_SUCCESS