From c5f52c6e308b03cd06276c2babc8d22f6abf7d04 Mon Sep 17 00:00:00 2001 From: Tyler Burton Date: Wed, 8 May 2024 11:52:09 -0500 Subject: [PATCH] updates method to return lookup table of identifiers and ids --- database/interface.py | 11 +++++++---- tests/unit/database/test_db.py | 15 +++++++++++---- tests/unit/test_load_manager.py | 5 ----- 3 files changed, 18 insertions(+), 13 deletions(-) diff --git a/database/interface.py b/database/interface.py index 9e2391ba..3f1d360e 100644 --- a/database/interface.py +++ b/database/interface.py @@ -1,4 +1,5 @@ import os +import uuid from sqlalchemy import create_engine, inspect, or_ from sqlalchemy.exc import NoResultFound from sqlalchemy.orm import scoped_session, sessionmaker @@ -223,22 +224,24 @@ def add_harvest_record(self, record_data): self.db.rollback() return None - def add_harvest_records(self, records_data: list) -> bool: + def add_harvest_records(self, records_data: list) -> dict: """ Add many records at once :param list records_data: List of records with unique UUIDs - :return bool success of operation + :return dict id_lookup_table: identifiers -> ids :raises Exception: if the records_data contains records with errors """ try: + id_lookup_table = {} for i, record_data in enumerate(records_data): - new_record = HarvestRecord(**record_data) + new_record = HarvestRecord(id=str(uuid.uuid4()), **record_data) + id_lookup_table[new_record.identifier] = new_record.id self.db.add(new_record) if i % 1000 == 0: self.db.flush() self.db.commit() - return True + return id_lookup_table except Exception as e: print("Error:", e) self.db.rollback() diff --git a/tests/unit/database/test_db.py b/tests/unit/database/test_db.py index 7f20949a..29338c89 100644 --- a/tests/unit/database/test_db.py +++ b/tests/unit/database/test_db.py @@ -116,10 +116,17 @@ def test_add_harvest_records( interface.add_harvest_source(source_data_dcatus) interface.add_harvest_job(job_data_dcatus) - records = [record_data_dcatus] * 10 - success = interface.add_harvest_records(records) - assert success is True - assert len(interface.get_all_harvest_records()) == 10 + records = [] + for i in range(10): + new_record = record_data_dcatus.copy() + new_record["identifier"] = f"test-identifier-{i}" + records.append(new_record) + + id_lookup_table = interface.add_harvest_records(records) + db_records = interface.get_all_harvest_records() + assert len(id_lookup_table) == 10 + assert len(db_records) == 10 + assert id_lookup_table[db_records[0]["identifier"]] == db_records[0]["id"] def test_add_harvest_job_with_id( self, interface, organization_data, source_data_dcatus, job_data_dcatus diff --git a/tests/unit/test_load_manager.py b/tests/unit/test_load_manager.py index ac80ed58..ed307b1b 100644 --- a/tests/unit/test_load_manager.py +++ b/tests/unit/test_load_manager.py @@ -16,11 +16,6 @@ def mock_bad_cf_index(monkeypatch): monkeypatch.setenv("CF_INSTANCE_INDEX", "1") -@pytest.fixture(autouse=True) -def mock_lm_config(monkeypatch): - monkeypatch.setenv("LM_RUNNER_APP_GUID", "f4ab7f86-bee0-44fd-8806-1dca7f8e215a") - - class TestLoadManager: @patch.object(HarvesterDBInterface, "update_harvest_job") @patch.object(CFHandler, "start_task")