From edd2b046b92255eb360a79aefdca68021a74cee5 Mon Sep 17 00:00:00 2001 From: Michael Terry Date: Tue, 10 Oct 2023 14:41:49 -0400 Subject: [PATCH] fix: guard against accidental deletes with ndjson output format The ndjson output formatter deletes all existing content in its target directory, because it is not an incremental formatter. But this introduces the possibility of accidentally using the ndjson output format on a delta lake folder or some other important folder. So this commit checks the contents of the target output dir and errors out if there are odd looking files in there. --- cumulus_etl/formats/batched_files.py | 23 +++++++++ tests/etl/base.py | 14 ++---- tests/formats/__init__.py | 0 tests/{ => formats}/test_deltalake.py | 0 tests/formats/test_ndjson.py | 72 +++++++++++++++++++++++++++ tests/test_bulk_export.py | 10 ++-- tests/utils.py | 7 +++ 7 files changed, 112 insertions(+), 14 deletions(-) create mode 100644 tests/formats/__init__.py rename tests/{ => formats}/test_deltalake.py (100%) create mode 100644 tests/formats/test_ndjson.py diff --git a/cumulus_etl/formats/batched_files.py b/cumulus_etl/formats/batched_files.py index f5cf83b5..de855ddf 100644 --- a/cumulus_etl/formats/batched_files.py +++ b/cumulus_etl/formats/batched_files.py @@ -1,7 +1,9 @@ """An implementation of Format designed to write in batches of files""" import abc +import re +from cumulus_etl import errors, store from cumulus_etl.formats.base import Format from cumulus_etl.formats.batch import Batch @@ -42,11 +44,32 @@ def __init__(self, *args, **kwargs) -> None: # Note: There is a real issue here where Athena will see invalid results until we've written all # our files out. Use the deltalake format to get atomic updates. parent_dir = self.root.joinpath(self.dbname) + self._confirm_no_unknown_files_exist(parent_dir) try: self.root.rm(parent_dir, recursive=True) except FileNotFoundError: pass + def _confirm_no_unknown_files_exist(self, folder: str) -> None: + """ + Errors out if any unknown files exist in the target dir already. + + This is designed to prevent accidents. + """ + try: + filenames = [path.split("/")[-1] for path in store.Root(folder).ls()] + except FileNotFoundError: + return # folder doesn't exist, we're good! + + allowed_pattern = re.compile(rf"{self.dbname}\.[0-9]+\.{self.suffix}") + if not all(map(allowed_pattern.fullmatch, filenames)): + errors.fatal( + f"There are unexpected files in the output folder '{folder}'.\n" + f"Please confirm you are using the right output format.\n" + f"If so, delete the output folder and try again.", + errors.FOLDER_NOT_EMPTY, + ) + def _write_one_batch(self, batch: Batch) -> None: """Writes the whole dataframe to a single file""" self.root.makedirs(self.root.joinpath(self.dbname)) diff --git a/tests/etl/base.py b/tests/etl/base.py index 71171f88..7062e499 100644 --- a/tests/etl/base.py +++ b/tests/etl/base.py @@ -101,17 +101,17 @@ def setUp(self) -> None: super().setUp() client = fhir.FhirClient("http://localhost/", []) - self.tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with - self.input_dir = os.path.join(self.tmpdir.name, "input") - self.phi_dir = os.path.join(self.tmpdir.name, "phi") - self.errors_dir = os.path.join(self.tmpdir.name, "errors") + self.tmpdir = self.make_tempdir() + self.input_dir = os.path.join(self.tmpdir, "input") + self.phi_dir = os.path.join(self.tmpdir, "phi") + self.errors_dir = os.path.join(self.tmpdir, "errors") os.makedirs(self.input_dir) os.makedirs(self.phi_dir) self.job_config = JobConfig( self.input_dir, self.input_dir, - self.tmpdir.name, + self.tmpdir, self.phi_dir, "ndjson", "ndjson", @@ -144,10 +144,6 @@ def make_formatter(dbname: str, group_field: str = None, resource_type: str = No # Keeps consistent IDs shutil.copy(os.path.join(self.datadir, "simple/codebook.json"), self.phi_dir) - def tearDown(self) -> None: - super().tearDown() - self.tmpdir = None - def make_json(self, filename, resource_id, **kwargs): common.write_json( os.path.join(self.input_dir, f"{filename}.ndjson"), {"resourceType": "Test", **kwargs, "id": resource_id} diff --git a/tests/formats/__init__.py b/tests/formats/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_deltalake.py b/tests/formats/test_deltalake.py similarity index 100% rename from tests/test_deltalake.py rename to tests/formats/test_deltalake.py diff --git a/tests/formats/test_ndjson.py b/tests/formats/test_ndjson.py new file mode 100644 index 00000000..593c45b2 --- /dev/null +++ b/tests/formats/test_ndjson.py @@ -0,0 +1,72 @@ +"""Tests for ndjson output format support""" + +import os + +import ddt + +from cumulus_etl import formats, store +from cumulus_etl.formats.ndjson import NdjsonFormat +from tests import utils + + +@ddt.ddt +class TestNdjsonFormat(utils.AsyncTestCase): + """ + Test case for the ndjson format writer. + + i.e. tests for ndjson.py + """ + + def setUp(self): + super().setUp() + self.output_tempdir = self.make_tempdir() + self.root = store.Root(self.output_tempdir) + NdjsonFormat.initialize_class(self.root) + + @staticmethod + def df(**kwargs) -> list[dict]: + """ + Creates a dummy Table with ids & values equal to each kwarg provided. + """ + return [{"id": k, "value": v} for k, v in kwargs.items()] + + def store( + self, + rows: list[dict], + batch_index: int = 10, + ) -> bool: + """ + Writes a single batch of data to the output dir. + + :param rows: the data to insert + :param batch_index: which batch number this is, defaulting to 10 to avoid triggering any first/last batch logic + """ + ndjson = NdjsonFormat(self.root, "condition") + batch = formats.Batch(rows, index=batch_index) + return ndjson.write_records(batch) + + @ddt.data( + (None, True), + ([], True), + (["condition.1234.ndjson", "condition.22.ndjson"], True), + (["condition.ndjson"], False), + (["condition.000.parquet"], False), + (["patient.000.ndjson"], False), + ) + @ddt.unpack + def test_handles_existing_files(self, files: None | list[str], is_ok: bool): + """Verify that we bail out if any weird files already exist in the output""" + dbpath = self.root.joinpath("condition") + if files is not None: + os.makedirs(dbpath) + for file in files: + with open(f"{dbpath}/{file}", "w", encoding="utf8") as f: + f.write('{"id": "A"}') + + if is_ok: + self.store([{"id": "B"}], batch_index=0) + self.assertEqual(["condition.000.ndjson"], os.listdir(dbpath)) + else: + with self.assertRaises(SystemExit): + self.store([{"id": "B"}]) + self.assertEqual(files or [], os.listdir(dbpath)) diff --git a/tests/test_bulk_export.py b/tests/test_bulk_export.py index 2001118e..352dd3eb 100644 --- a/tests/test_bulk_export.py +++ b/tests/test_bulk_export.py @@ -25,11 +25,11 @@ class TestBulkExporter(AsyncTestCase): def setUp(self): super().setUp() - self.tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with + self.tmpdir = self.make_tempdir() self.server = mock.AsyncMock() def make_exporter(self, **kwargs) -> BulkExporter: - return BulkExporter(self.server, ["Condition", "Patient"], "https://localhost/", self.tmpdir.name, **kwargs) + return BulkExporter(self.server, ["Condition", "Patient"], "https://localhost/", self.tmpdir, **kwargs) async def export(self, **kwargs) -> BulkExporter: exporter = self.make_exporter(**kwargs) @@ -79,9 +79,9 @@ async def test_happy_path(self): self.server.request.call_args_list, ) - self.assertEqual({"type": "Condition1"}, common.read_json(f"{self.tmpdir.name}/Condition.000.ndjson")) - self.assertEqual({"type": "Condition2"}, common.read_json(f"{self.tmpdir.name}/Condition.001.ndjson")) - self.assertEqual({"type": "Patient1"}, common.read_json(f"{self.tmpdir.name}/Patient.000.ndjson")) + self.assertEqual({"type": "Condition1"}, common.read_json(f"{self.tmpdir}/Condition.000.ndjson")) + self.assertEqual({"type": "Condition2"}, common.read_json(f"{self.tmpdir}/Condition.001.ndjson")) + self.assertEqual({"type": "Patient1"}, common.read_json(f"{self.tmpdir}/Patient.000.ndjson")) async def test_since_until(self): """Verify that we send since & until parameters correctly to the server""" diff --git a/tests/utils.py b/tests/utils.py index 34fa72ca..f3be8f50 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -7,6 +7,7 @@ import inspect import json import os +import tempfile import time import tracemalloc import unittest @@ -46,6 +47,12 @@ def setUp(self): # Make it easy to grab test data, regardless of where the test is self.datadir = os.path.join(os.path.dirname(__file__), "data") + def make_tempdir(self) -> str: + """Creates a temporary dir that will be automatically cleaned up""" + tempdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with + self.addCleanup(tempdir.cleanup) + return tempdir.name + def patch(self, *args, **kwargs) -> mock.Mock: """Syntactic sugar to ease making a mock over a test's lifecycle, without decorators""" patcher = mock.patch(*args, **kwargs)