Merge pull request #281 from smart-on-fhir/mikix/ndjson-safety
fix: guard against accidental deletes with ndjson output format
mikix authored Oct 10, 2023
2 parents 2966ea8 + edd2b04 commit 494998f
Showing 7 changed files with 112 additions and 14 deletions.
23 changes: 23 additions & 0 deletions cumulus_etl/formats/batched_files.py
@@ -1,7 +1,9 @@
"""An implementation of Format designed to write in batches of files"""

import abc
import re

from cumulus_etl import errors, store
from cumulus_etl.formats.base import Format
from cumulus_etl.formats.batch import Batch

@@ -42,11 +44,32 @@ def __init__(self, *args, **kwargs) -> None:
# Note: There is a real issue here where Athena will see invalid results until we've written all
# our files out. Use the deltalake format to get atomic updates.
parent_dir = self.root.joinpath(self.dbname)
self._confirm_no_unknown_files_exist(parent_dir)
try:
self.root.rm(parent_dir, recursive=True)
except FileNotFoundError:
pass

def _confirm_no_unknown_files_exist(self, folder: str) -> None:
"""
Errors out if any unknown files exist in the target dir already.
This is designed to prevent accidents.
"""
try:
filenames = [path.split("/")[-1] for path in store.Root(folder).ls()]
except FileNotFoundError:
return # folder doesn't exist, we're good!

allowed_pattern = re.compile(rf"{self.dbname}\.[0-9]+\.{self.suffix}")
if not all(map(allowed_pattern.fullmatch, filenames)):
errors.fatal(
f"There are unexpected files in the output folder '{folder}'.\n"
f"Please confirm you are using the right output format.\n"
f"If so, delete the output folder and try again.",
errors.FOLDER_NOT_EMPTY,
)

def _write_one_batch(self, batch: Batch) -> None:
"""Writes the whole dataframe to a single file"""
self.root.makedirs(self.root.joinpath(self.dbname))
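
A quick illustration of the allow-list above: the guard only tolerates files this writer could have produced itself, named like condition.000.ndjson (dbname, batch index, suffix). A minimal standalone sketch of the check; the dbname, suffix, and example filenames here are assumed for illustration:

import re

dbname, suffix = "condition", "ndjson"
allowed = re.compile(rf"{dbname}\.[0-9]+\.{suffix}")

for name in ("condition.000.ndjson", "condition.ndjson", "patient.000.ndjson"):
    # fullmatch() so a stray prefix or suffix can't sneak past the pattern
    print(name, bool(allowed.fullmatch(name)))
# condition.000.ndjson True  -- looks like one of our own batch files
# condition.ndjson False     -- no batch index, so we refuse to delete it
# patient.000.ndjson False   -- wrong dbname for this folder
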
14 changes: 5 additions & 9 deletions tests/etl/base.py
@@ -101,17 +101,17 @@ def setUp(self) -> None:
super().setUp()

client = fhir.FhirClient("http://localhost/", [])
self.tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
self.input_dir = os.path.join(self.tmpdir.name, "input")
self.phi_dir = os.path.join(self.tmpdir.name, "phi")
self.errors_dir = os.path.join(self.tmpdir.name, "errors")
self.tmpdir = self.make_tempdir()
self.input_dir = os.path.join(self.tmpdir, "input")
self.phi_dir = os.path.join(self.tmpdir, "phi")
self.errors_dir = os.path.join(self.tmpdir, "errors")
os.makedirs(self.input_dir)
os.makedirs(self.phi_dir)

self.job_config = JobConfig(
self.input_dir,
self.input_dir,
self.tmpdir.name,
self.tmpdir,
self.phi_dir,
"ndjson",
"ndjson",
@@ -144,10 +144,6 @@ def make_formatter(dbname: str, group_field: str = None, resource_type: str = No
# Keeps consistent IDs
shutil.copy(os.path.join(self.datadir, "simple/codebook.json"), self.phi_dir)

def tearDown(self) -> None:
super().tearDown()
self.tmpdir = None

def make_json(self, filename, resource_id, **kwargs):
common.write_json(
os.path.join(self.input_dir, f"{filename}.ndjson"), {"resourceType": "Test", **kwargs, "id": resource_id}
Empty file added tests/formats/__init__.py
Empty file.
File renamed without changes.
72 changes: 72 additions & 0 deletions tests/formats/test_ndjson.py
@@ -0,0 +1,72 @@
"""Tests for ndjson output format support"""

import os

import ddt

from cumulus_etl import formats, store
from cumulus_etl.formats.ndjson import NdjsonFormat
from tests import utils


@ddt.ddt
class TestNdjsonFormat(utils.AsyncTestCase):
"""
Test case for the ndjson format writer.
i.e. tests for ndjson.py
"""

def setUp(self):
super().setUp()
self.output_tempdir = self.make_tempdir()
self.root = store.Root(self.output_tempdir)
NdjsonFormat.initialize_class(self.root)

@staticmethod
def df(**kwargs) -> list[dict]:
"""
Creates a dummy Table with ids & values equal to each kwarg provided.
"""
return [{"id": k, "value": v} for k, v in kwargs.items()]

def store(
self,
rows: list[dict],
batch_index: int = 10,
) -> bool:
"""
Writes a single batch of data to the output dir.
:param rows: the data to insert
:param batch_index: which batch number this is, defaulting to 10 to avoid triggering any first/last batch logic
"""
ndjson = NdjsonFormat(self.root, "condition")
batch = formats.Batch(rows, index=batch_index)
return ndjson.write_records(batch)

@ddt.data(
(None, True),
([], True),
(["condition.1234.ndjson", "condition.22.ndjson"], True),
(["condition.ndjson"], False),
(["condition.000.parquet"], False),
(["patient.000.ndjson"], False),
)
@ddt.unpack
def test_handles_existing_files(self, files: None | list[str], is_ok: bool):
"""Verify that we bail out if any weird files already exist in the output"""
dbpath = self.root.joinpath("condition")
if files is not None:
os.makedirs(dbpath)
for file in files:
with open(f"{dbpath}/{file}", "w", encoding="utf8") as f:
f.write('{"id": "A"}')

if is_ok:
self.store([{"id": "B"}], batch_index=0)
self.assertEqual(["condition.000.ndjson"], os.listdir(dbpath))
else:
with self.assertRaises(SystemExit):
self.store([{"id": "B"}])
self.assertEqual(files or [], os.listdir(dbpath))
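
For reference, the ddt decorators above generate one test method per tuple, with @ddt.unpack spreading each tuple across the test's arguments. A minimal standalone sketch of the pattern, separate from this repo's code:

import unittest

import ddt

@ddt.ddt
class DemoTest(unittest.TestCase):
    # ddt generates a separate test method for each tuple below
    @ddt.data(("condition.000.ndjson", True), ("condition.ndjson", False))
    @ddt.unpack
    def test_filename_shape(self, filename: str, is_ok: bool):
        self.assertEqual(is_ok, filename.count(".") == 2)
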
10 changes: 5 additions & 5 deletions tests/test_bulk_export.py
@@ -25,11 +25,11 @@ class TestBulkExporter(AsyncTestCase):

def setUp(self):
super().setUp()
self.tmpdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
self.tmpdir = self.make_tempdir()
self.server = mock.AsyncMock()

def make_exporter(self, **kwargs) -> BulkExporter:
return BulkExporter(self.server, ["Condition", "Patient"], "https://localhost/", self.tmpdir.name, **kwargs)
return BulkExporter(self.server, ["Condition", "Patient"], "https://localhost/", self.tmpdir, **kwargs)

async def export(self, **kwargs) -> BulkExporter:
exporter = self.make_exporter(**kwargs)
@@ -79,9 +79,9 @@ async def test_happy_path(self):
self.server.request.call_args_list,
)

self.assertEqual({"type": "Condition1"}, common.read_json(f"{self.tmpdir.name}/Condition.000.ndjson"))
self.assertEqual({"type": "Condition2"}, common.read_json(f"{self.tmpdir.name}/Condition.001.ndjson"))
self.assertEqual({"type": "Patient1"}, common.read_json(f"{self.tmpdir.name}/Patient.000.ndjson"))
self.assertEqual({"type": "Condition1"}, common.read_json(f"{self.tmpdir}/Condition.000.ndjson"))
self.assertEqual({"type": "Condition2"}, common.read_json(f"{self.tmpdir}/Condition.001.ndjson"))
self.assertEqual({"type": "Patient1"}, common.read_json(f"{self.tmpdir}/Patient.000.ndjson"))

async def test_since_until(self):
"""Verify that we send since & until parameters correctly to the server"""
7 changes: 7 additions & 0 deletions tests/utils.py
@@ -7,6 +7,7 @@
import inspect
import json
import os
import tempfile
import time
import tracemalloc
import unittest
@@ -46,6 +47,12 @@ def setUp(self):
# Make it easy to grab test data, regardless of where the test is
self.datadir = os.path.join(os.path.dirname(__file__), "data")

def make_tempdir(self) -> str:
"""Creates a temporary dir that will be automatically cleaned up"""
tempdir = tempfile.TemporaryDirectory() # pylint: disable=consider-using-with
self.addCleanup(tempdir.cleanup)
return tempdir.name

def patch(self, *args, **kwargs) -> mock.Mock:
"""Syntactic sugar to ease making a mock over a test's lifecycle, without decorators"""
patcher = mock.patch(*args, **kwargs)
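
The new helper leans on unittest's addCleanup(), so the directory is removed when the test finishes even if setUp() later fails partway, and callers get a plain str path rather than a TemporaryDirectory object to keep alive. A minimal usage sketch; the test class and body here are assumed for illustration:

import os

class ExampleTest(AsyncTestCase):
    def setUp(self):
        super().setUp()
        # Plain string path; cleanup is registered automatically
        self.workdir = self.make_tempdir()

    def test_writes_output(self):
        path = os.path.join(self.workdir, "out.ndjson")
        with open(path, "w", encoding="utf8") as f:
            f.write('{"id": "A"}\n')
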
