Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add option to cli to create neo4j db constraints #412

Merged
merged 1 commit into from
Dec 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 30 additions & 6 deletions src/metakb/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,12 +330,15 @@ async def transform_file(
)


def _get_driver(db_url: str, db_creds: str | None) -> Generator[Driver, None, None]:
def _get_driver(
db_url: str, db_creds: str | None, add_constraints: bool
) -> Generator[Driver, None, None]:
"""Acquire Neo4j graph driver.

:param db_url: URL endpoint for the application Neo4j database.
:param db_creds: DB username and password, separated by a colon, e.g.
``"username:password"``.
:param add_constraints: Whether or not to create Neo4j database constraints.
:return: Graph driver instance
"""
if not db_creds:
Expand All @@ -348,7 +351,9 @@ def _get_driver(db_url: str, db_creds: str | None) -> Generator[Driver, None, No
_help_msg(
f"Argument to --db_credentials appears invalid. Got '{db_creds}'. Should follow pattern 'username:password'."
)
driver = get_driver(uri=db_url, credentials=credentials)
driver = get_driver(
uri=db_url, credentials=credentials, add_constraints=add_constraints
)
yield driver
driver.close()

Expand Down Expand Up @@ -382,13 +387,19 @@ def clear_graph(
``"username:password"``.
:param keep_constraints: if True, don't clear graph constraints
""" # noqa: D301
driver = next(_get_driver(db_url, db_credentials))
driver = next(_get_driver(db_url, db_credentials, add_constraints=False))
clear_metakb_graph(driver, keep_constraints)


@cli.command()
@click.option("--db_url", "-u", default="", help=_neo4j_db_url_description)
@click.option("--db_credentials", "-c", help=_neo4j_creds_description)
@click.option(
"--add_constraints",
is_flag=True,
default=False,
help="if true, create neo4j database constraints",
)
@click.option(
"--from_s3",
"-s",
Expand All @@ -402,7 +413,11 @@ def clear_graph(
nargs=-1,
)
def load_cdm(
db_url: str, db_credentials: str | None, from_s3: bool, cdm_files: tuple[Path, ...]
db_url: str,
db_credentials: str | None,
add_constraints: bool,
from_s3: bool,
cdm_files: tuple[Path, ...],
) -> None:
"""Load one or more CDM_FILEs into Neo4j graph.

Expand Down Expand Up @@ -430,6 +445,7 @@ def load_cdm(
:param db_url: URL endpoint for the application Neo4j database.
:param db_credentials: DB username and password, separated by a colon, e.g.
``"username:password"``.
:param add_constraints: Whether or not to create Neo4j database constraints.
:param from_s3: Skip data harvest/transform and load latest existing CDM files from
VICC S3 bucket. Exclusive with ``cdm_file`` arguments.
:param cdm_files: tuple of specific file(s) to load from. If empty, just get latest
Expand All @@ -441,7 +457,7 @@ def load_cdm(
start = timer()
_echo_info("Loading Neo4j database...")

driver = next(_get_driver(db_url, db_credentials))
driver = next(_get_driver(db_url, db_credentials, add_constraints))

if cdm_files:
for file in cdm_files:
Expand All @@ -468,6 +484,12 @@ def load_cdm(
@cli.command()
@click.option("--db_url", "-u", default="", help=_neo4j_db_url_description)
@click.option("--db_credentials", "-c", help=_neo4j_creds_description)
@click.option(
"--add_constraints",
is_flag=True,
default=False,
help="if true, create neo4j database constraints",
)
@click.option("--normalizer_db_url", "-n", help=_normalizer_db_url_description)
@click.option(
"--refresh_source_caches",
Expand All @@ -487,6 +509,7 @@ def load_cdm(
async def update(
db_url: str,
db_credentials: str | None,
add_constraints: bool,
normalizer_db_url: str | None,
refresh_source_caches: bool,
sources: tuple[SourceName, ...],
Expand Down Expand Up @@ -515,6 +538,7 @@ async def update(
:param db_url: URL endpoint for the application Neo4j database.
:param db_credentials: DB username and password, separated by a colon, e.g.
``"username:password"``.
:param add_constraints: Whether or not to create Neo4j database constraints.
:param normalizer_db_url: URL endpoint of normalizers DynamoDB database. If not
given, defaults to the configuration rules of the individual normalizers.
:param refresh_source_caches: ``True`` if source caches, i.e. CIViCPy, should be
Expand All @@ -528,7 +552,7 @@ async def update(
start = timer()
_echo_info("Loading Neo4j database...")

driver = next(_get_driver(db_url, db_credentials))
driver = next(_get_driver(db_url, db_credentials, add_constraints))

if not sources:
sources = tuple(SourceName)
Expand Down
Loading