Quick fixes (no warning, allow . in catalog name). (#307)
* Quick fixes (no warning, allow . in catalog name).

* Address pylint warnings.
delucchi-cmu authored May 14, 2024
1 parent e6a1ade commit 1555ca2
Showing 7 changed files with 76 additions and 91 deletions.
5 changes: 3 additions & 2 deletions src/hipscat_import/catalog/file_readers.py
@@ -136,6 +136,7 @@ def __init__(
self.parquet_kwargs = parquet_kwargs
self.kwargs = kwargs

schema_parquet = None
if self.schema_file:
if self.parquet_kwargs is None:
self.parquet_kwargs = {}
@@ -146,12 +147,12 @@

if self.column_names:
self.kwargs["names"] = self.column_names
elif not self.header and self.schema_file:
elif not self.header and schema_parquet is not None:
self.kwargs["names"] = schema_parquet.columns

if self.type_map:
self.kwargs["dtype"] = self.type_map
elif self.schema_file:
elif schema_parquet is not None:
self.kwargs["dtype"] = schema_parquet.dtypes.to_dict()

def read(self, input_file, read_columns=None):
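The `schema_parquet = None` initialisation, together with switching the guards from `self.schema_file` to `schema_parquet is not None`, makes the fallback logic key off the value that was actually loaded and quiets the likely pylint complaint about a possibly-undefined name. A minimal standalone sketch of the same pattern — the function name and the schema-loading call are illustrative, not the library's API:

```python
import pandas as pd

def build_reader_kwargs(schema_file=None, column_names=None, header=0, type_map=None, **kwargs):
    """Illustrative reproduction of the guard pattern used in CsvReader.__init__."""
    schema_parquet = None  # defined up front so the later checks never see an unbound name
    if schema_file:
        # stand-in for reading the parquet schema into an (empty) dataframe
        schema_parquet = pd.read_parquet(schema_file)

    if column_names:
        kwargs["names"] = column_names
    elif not header and schema_parquet is not None:  # key off the loaded value, not the flag
        kwargs["names"] = schema_parquet.columns

    if type_map:
        kwargs["dtype"] = type_map
    elif schema_parquet is not None:
        kwargs["dtype"] = schema_parquet.dtypes.to_dict()
    return kwargs
```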
5 changes: 1 addition & 4 deletions src/hipscat_import/pipeline_resume_plan.py
@@ -3,7 +3,6 @@
from __future__ import annotations

import re
import warnings
from dataclasses import dataclass
from pathlib import Path

@@ -38,9 +37,7 @@ def safe_to_resume(self):
if not self.resume:
self.clean_resume_files()
else:
warnings.warn(
f"tmp_path ({self.tmp_path}) contains intermediate files; resuming prior progress."
)
print(f"tmp_path ({self.tmp_path}) contains intermediate files; resuming prior progress.")
file_io.make_directory(self.tmp_path, exist_ok=True)

def done_file_exists(self, stage_name):
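Replacing `warnings.warn` with `print` means the resume notice no longer goes through the warnings machinery, so it cannot be escalated by warning filters and the `pytest.warns(UserWarning, ...)` wrappers removed in the test files below are no longer needed. A self-contained illustration of that difference (not code from the commit):

```python
import warnings

warnings.simplefilter("error")  # escalate all warnings, as a strict test run might

# A plain print is untouched by warning filters.
print("tmp_path contains intermediate files; resuming prior progress.")

# The old warnings.warn call escalates to an exception under this filter,
# which is why tests previously had to wrap resume calls in pytest.warns.
try:
    warnings.warn("tmp_path contains intermediate files; resuming prior progress.")
except UserWarning as exc:
    print(f"warning escalated: {exc}")
```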
3 changes: 2 additions & 1 deletion src/hipscat_import/runtime_arguments.py
@@ -63,7 +63,7 @@ def _check_arguments(self):
raise ValueError("output_path is required")
if not self.output_artifact_name:
raise ValueError("output_artifact_name is required")
if re.search(r"[^A-Za-z0-9_\-\\]", self.output_artifact_name):
if re.search(r"[^A-Za-z0-9\._\-\\]", self.output_artifact_name):
raise ValueError("output_artifact_name contains invalid characters")

if self.dask_n_workers <= 0:
@@ -145,6 +145,7 @@ def find_input_paths(
Raises:
FileNotFoundError: if no files are found at the input_path and the provided list is empty.
"""
input_paths = []
if input_path:
if not file_io.does_file_or_directory_exist(input_path, storage_options=storage_options):
raise FileNotFoundError("input_path not found on local storage")
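Two small things happen in this file: the character class in the artifact-name check now includes `\.`, and `find_input_paths` gives `input_paths` a default empty list before the branches that may assign it (the same defensive initialisation as `schema_parquet` above). For illustration, a quick standalone check of the widened pattern — the snippet is mine, not part of the commit:

```python
import re

ARTIFACT_NAME_PATTERN = r"[^A-Za-z0-9\._\-\\]"  # updated character class from _check_arguments

# A dot is no longer flagged as an invalid character, so names like
# "small_sky_source_catalog.parquet" (used in test_run_round_trip.py below) pass.
assert re.search(ARTIFACT_NAME_PATTERN, "small_sky_source_catalog.parquet") is None

# Other punctuation and whitespace are still rejected.
assert re.search(ARTIFACT_NAME_PATTERN, "catalog name!") is not None
```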
60 changes: 26 additions & 34 deletions tests/hipscat_import/catalog/test_resume_plan.py
@@ -37,22 +37,19 @@ def test_done_checks(tmp_path):
plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, resume=True)
plan.touch_stage_done_file(ResumePlan.REDUCING_STAGE)

with pytest.warns(UserWarning, match="resuming prior progress"):
with pytest.raises(ValueError, match="before reducing"):
plan.gather_plan()
with pytest.raises(ValueError, match="before reducing"):
plan.gather_plan()

plan.touch_stage_done_file(ResumePlan.SPLITTING_STAGE)
with pytest.warns(UserWarning, match="resuming prior progress"):
with pytest.raises(ValueError, match="before reducing"):
plan.gather_plan()
with pytest.raises(ValueError, match="before reducing"):
plan.gather_plan()

plan.clean_resume_files()

plan = ResumePlan(tmp_path=tmp_path, progress_bar=False, resume=True)
plan.touch_stage_done_file(ResumePlan.SPLITTING_STAGE)
with pytest.warns(UserWarning, match="resuming prior progress"):
with pytest.raises(ValueError, match="before splitting"):
plan.gather_plan()
with pytest.raises(ValueError, match="before splitting"):
plan.gather_plan()


def test_same_input_paths(tmp_path, small_sky_single_file, formats_headers_csv):
@@ -66,33 +63,30 @@ def test_same_input_paths(tmp_path, small_sky_single_file, formats_headers_csv):
map_files = plan.map_files
assert len(map_files) == 2

with pytest.warns(UserWarning, match="resuming prior progress"):
with pytest.raises(ValueError, match="Different file set"):
ResumePlan(
tmp_path=tmp_path,
progress_bar=False,
resume=True,
input_paths=[small_sky_single_file],
)
with pytest.raises(ValueError, match="Different file set"):
ResumePlan(
tmp_path=tmp_path,
progress_bar=False,
resume=True,
input_paths=[small_sky_single_file],
)

## List is the same length, but includes a duplicate
with pytest.warns(UserWarning, match="resuming prior progress"):
with pytest.raises(ValueError, match="Different file set"):
ResumePlan(
tmp_path=tmp_path,
progress_bar=False,
resume=True,
input_paths=[small_sky_single_file, small_sky_single_file],
)

## Includes a duplicate file, but that's ok.
with pytest.warns(UserWarning, match="resuming prior progress"):
plan = ResumePlan(
with pytest.raises(ValueError, match="Different file set"):
ResumePlan(
tmp_path=tmp_path,
progress_bar=False,
resume=True,
input_paths=[small_sky_single_file, small_sky_single_file, formats_headers_csv],
input_paths=[small_sky_single_file, small_sky_single_file],
)

## Includes a duplicate file, but that's ok.
plan = ResumePlan(
tmp_path=tmp_path,
progress_bar=False,
resume=True,
input_paths=[small_sky_single_file, small_sky_single_file, formats_headers_csv],
)
map_files = plan.map_files
assert len(map_files) == 2

@@ -150,15 +144,13 @@ def test_read_write_splitting_keys(tmp_path, small_sky_single_file, formats_head

ResumePlan.touch_key_done_file(tmp_path, ResumePlan.SPLITTING_STAGE, "split_0")

with pytest.warns(UserWarning, match="resuming prior progress"):
plan.gather_plan()
plan.gather_plan()
split_keys = plan.split_keys
assert len(split_keys) == 1
assert split_keys[0][0] == "split_1"

ResumePlan.touch_key_done_file(tmp_path, ResumePlan.SPLITTING_STAGE, "split_1")
with pytest.warns(UserWarning, match="resuming prior progress"):
plan.gather_plan()
plan.gather_plan()
split_keys = plan.split_keys
assert len(split_keys) == 0

83 changes: 40 additions & 43 deletions tests/hipscat_import/catalog/test_run_import.py
@@ -68,19 +68,18 @@ def test_resume_dask_runner(
os.path.join(tmp_path, "resume_catalog", "Norder=0"),
)

with pytest.warns(UserWarning, match="resuming prior progress"):
args = ImportArguments(
output_artifact_name="resume_catalog",
input_path=small_sky_parts_dir,
file_reader="csv",
output_path=tmp_path,
dask_tmp=tmp_path,
tmp_dir=tmp_path,
resume_tmp=os.path.join(tmp_path, "tmp"),
highest_healpix_order=0,
pixel_threshold=1000,
progress_bar=False,
)
args = ImportArguments(
output_artifact_name="resume_catalog",
input_path=small_sky_parts_dir,
file_reader="csv",
output_path=tmp_path,
dask_tmp=tmp_path,
tmp_dir=tmp_path,
resume_tmp=os.path.join(tmp_path, "tmp"),
highest_healpix_order=0,
pixel_threshold=1000,
progress_bar=False,
)

runner.run(args, dask_client)

@@ -166,21 +165,20 @@ def test_resume_dask_runner_diff_pixel_order(
os.path.join(tmp_path, "resume_catalog", "Norder=0"),
)

with pytest.warns(UserWarning, match="resuming prior progress"):
with pytest.raises(ValueError, match="incompatible with the highest healpix order"):
args = ImportArguments(
output_artifact_name="resume_catalog",
input_path=small_sky_parts_dir,
file_reader="csv",
output_path=tmp_path,
dask_tmp=tmp_path,
tmp_dir=tmp_path,
resume_tmp=os.path.join(tmp_path, "tmp"),
constant_healpix_order=1,
pixel_threshold=1000,
progress_bar=False,
)
runner.run(args, dask_client)
with pytest.raises(ValueError, match="incompatible with the highest healpix order"):
args = ImportArguments(
output_artifact_name="resume_catalog",
input_path=small_sky_parts_dir,
file_reader="csv",
output_path=tmp_path,
dask_tmp=tmp_path,
tmp_dir=tmp_path,
resume_tmp=os.path.join(tmp_path, "tmp"),
constant_healpix_order=1,
pixel_threshold=1000,
progress_bar=False,
)
runner.run(args, dask_client)

# Running with resume set to "False" will start the pipeline from scratch
args = ImportArguments(
@@ -240,21 +238,20 @@ def test_resume_dask_runner_histograms_diff_size(
else:
wrong_histogram.to_file(histogram_file)

with pytest.warns(UserWarning, match="resuming prior progress"):
with pytest.raises(ValueError, match="histogram partials have incompatible sizes"):
args = ImportArguments(
output_artifact_name="resume_catalog",
input_path=small_sky_parts_dir,
file_reader="csv",
output_path=tmp_path,
dask_tmp=tmp_path,
tmp_dir=tmp_path,
resume_tmp=os.path.join(tmp_path, "tmp"),
constant_healpix_order=1,
pixel_threshold=1000,
progress_bar=False,
)
runner.run(args, dask_client)
with pytest.raises(ValueError, match="histogram partials have incompatible sizes"):
args = ImportArguments(
output_artifact_name="resume_catalog",
input_path=small_sky_parts_dir,
file_reader="csv",
output_path=tmp_path,
dask_tmp=tmp_path,
tmp_dir=tmp_path,
resume_tmp=os.path.join(tmp_path, "tmp"),
constant_healpix_order=1,
pixel_threshold=1000,
progress_bar=False,
)
runner.run(args, dask_client)


@pytest.mark.dask
2 changes: 1 addition & 1 deletion tests/hipscat_import/catalog/test_run_round_trip.py
@@ -34,7 +34,7 @@ def test_import_source_table(
- will have larger partition info than the corresponding object catalog
"""
args = ImportArguments(
output_artifact_name="small_sky_source_catalog",
output_artifact_name="small_sky_source_catalog.parquet",
input_path=small_sky_source_dir,
file_reader="csv",
catalog_type="source",
9 changes: 3 additions & 6 deletions tests/hipscat_import/soap/test_soap_resume_plan.py
@@ -94,15 +94,13 @@ def test_count_keys(small_sky_soap_args):
## Mark one done and check that there's one less key to count later.
Path(small_sky_soap_args.tmp_path, "2_187.csv").touch()

with pytest.warns(UserWarning, match="resuming prior progress"):
plan.gather_plan(small_sky_soap_args)
plan.gather_plan(small_sky_soap_args)
assert len(plan.count_keys) == 13

## Mark them ALL done and check that there are no keys later.
plan.touch_stage_done_file(SoapPlan.COUNTING_STAGE)

with pytest.warns(UserWarning, match="resuming prior progress"):
plan.gather_plan(small_sky_soap_args)
plan.gather_plan(small_sky_soap_args)
assert len(plan.count_keys) == 0


@@ -117,8 +115,7 @@ def test_cached_map_file(small_sky_soap_args):
cache_map_file = os.path.join(small_sky_soap_args.tmp_path, SoapPlan.SOURCE_MAP_FILE)
assert os.path.exists(cache_map_file)

with pytest.warns(UserWarning, match="resuming prior progress"):
plan = SoapPlan(small_sky_soap_args)
plan = SoapPlan(small_sky_soap_args)
assert len(plan.count_keys) == 14


