Add ECSV reader (#288)
* Add ECSV reader

* Test schema of ECSV-generated dataset
delucchi-cmu authored May 3, 2024
1 parent d67409b commit 7a5c17c
Showing 3 changed files with 522 additions and 1 deletion.
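
For context: ECSV is astropy's Enhanced Character-Separated Values format, a plain-text table whose commented YAML header records column names and datatypes. That typed header is what lets the reader added below recover a schema without a separate schema file. A minimal illustration of the format with hypothetical columns and values (the actual gaia_epoch.ecsv test fixture is not rendered on this page):

# %ECSV 1.0
# ---
# datatype:
# - {name: source_id, datatype: int64}
# - {name: ra, datatype: float64}
# - {name: dec, datatype: float64}
# schema: astropy-2.0
source_id ra dec
1 10.0 -45.0
2 11.5 -44.2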
26 changes: 25 additions & 1 deletion src/hipscat_import/catalog/file_readers.py
@@ -4,6 +4,7 @@
 from typing import Any, Dict, Union
 
 import pyarrow.parquet as pq
+from astropy.io import ascii as ascii_reader
 from astropy.table import Table
 from hipscat.io import FilePointer, file_io
 
@@ -38,14 +39,16 @@ def get_file_reader(
         skip_column_names (list[str]): for fits files, a list of columns to remove.
         type_map (dict): for CSV files, the data types to use for columns
     """
-    if "csv" in file_format:
+    if file_format == "csv":
         return CsvReader(
             chunksize=chunksize,
             schema_file=schema_file,
             column_names=column_names,
             type_map=type_map,
             **kwargs,
         )
+    if file_format == "ecsv":
+        return AstropyEcsvReader(**kwargs)
     if file_format == "fits":
         return FitsReader(
             chunksize=chunksize,
@@ -180,6 +183,27 @@ def provenance_info(self) -> dict:
         return provenance_info
 
 
+class AstropyEcsvReader(InputReader):
+    """Reads astropy ascii .ecsv files.
+
+    Note that this is NOT a chunked reader. Use caution when reading
+    large ECSV files with this reader."""
+
+    def __init__(self, **kwargs):
+        self.kwargs = kwargs
+
+    def read(self, input_file, read_columns=None):
+        self.regular_file_exists(input_file, **self.kwargs)
+        if read_columns:
+            self.kwargs["include_names"] = read_columns
+
+        astropy_table = ascii_reader.read(input_file, format="ecsv", **self.kwargs)
+        yield astropy_table.to_pandas()
+
+    def provenance_info(self):
+        return {"input_reader_type": "AstropyEcsvReader"}
+
+
 class FitsReader(InputReader):
     """Chunked FITS file reader.
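As a usage sketch (not part of the commit): the new reader can be constructed directly or through the get_file_reader factory above, and read() yields the whole table as a single pandas DataFrame because the reader is not chunked. The file path below is hypothetical; any extra keyword arguments are forwarded to astropy.io.ascii.read.

from hipscat_import.catalog.file_readers import get_file_reader

# "ecsv" now dispatches to AstropyEcsvReader; kwargs go to astropy's reader.
reader = get_file_reader("ecsv")

# read() is a generator, but it yields exactly one frame containing the
# entire file, so large inputs are materialized in memory all at once.
for frame in reader.read("gaia_epoch.ecsv"):  # hypothetical local path
    print(frame.dtypes)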
114 changes: 114 additions & 0 deletions tests/hipscat_import/catalog/test_run_round_trip.py
@@ -12,6 +12,7 @@
 import numpy.testing as npt
 import pandas as pd
 import pyarrow as pa
+import pyarrow.dataset as pds
 import pyarrow.parquet as pq
 import pytest
 from hipscat.catalog.catalog import Catalog
@@ -477,3 +478,116 @@ def test_import_gaia_minimum(
     assert "Norder" in column_names
     assert "Dir" in column_names
     assert "Npix" in column_names
+
+
+@pytest.mark.dask
+def test_gaia_ecsv(
+    dask_client,
+    formats_dir,
+    tmp_path,
+    assert_parquet_file_ids,
+):
+    input_file = os.path.join(formats_dir, "gaia_epoch.ecsv")
+
+    args = ImportArguments(
+        output_artifact_name="gaia_e_astropy",
+        input_file_list=[input_file],
+        file_reader="ecsv",
+        ra_column="ra",
+        dec_column="dec",
+        sort_columns="solution_id,source_id",
+        output_path=tmp_path,
+        dask_tmp=tmp_path,
+        highest_healpix_order=2,
+        pixel_threshold=3_000,
+        progress_bar=False,
+    )
+
+    runner.run(args, dask_client)
+
+    # Check that the catalog metadata file exists
+    catalog = Catalog.read_from_hipscat(args.catalog_path)
+    assert catalog.on_disk
+    assert catalog.catalog_path == args.catalog_path
+    assert catalog.catalog_info.total_rows == 3
+    assert len(catalog.get_healpix_pixels()) == 1
+
+    output_file = os.path.join(args.catalog_path, "Norder=0", "Dir=0", "Npix=0.parquet")
+
+    assert_parquet_file_ids(output_file, "source_id", [10655814178816, 10892037246720, 14263587225600])
+
+    # Check that the schema is correct for leaf parquet and _metadata files
+    expected_parquet_schema = pa.schema(
+        [
+            pa.field("solution_id", pa.int64()),
+            pa.field("source_id", pa.int64()),
+            pa.field("ra", pa.float64()),
+            pa.field("dec", pa.float64()),
+            pa.field("n_transits", pa.int16()),
+            pa.field("transit_id", pa.list_(pa.int64())),
+            pa.field("g_transit_time", pa.list_(pa.float64())),
+            pa.field("g_transit_flux", pa.list_(pa.float64())),
+            pa.field("g_transit_flux_error", pa.list_(pa.float64())),
+            pa.field("g_transit_flux_over_error", pa.list_(pa.float32())),
+            pa.field("g_transit_mag", pa.list_(pa.float64())),
+            pa.field("g_transit_n_obs", pa.list_(pa.int8())),
+            pa.field("bp_obs_time", pa.list_(pa.float64())),
+            pa.field("bp_flux", pa.list_(pa.float64())),
+            pa.field("bp_flux_error", pa.list_(pa.float64())),
+            pa.field("bp_flux_over_error", pa.list_(pa.float32())),
+            pa.field("bp_mag", pa.list_(pa.float64())),
+            pa.field("rp_obs_time", pa.list_(pa.float64())),
+            pa.field("rp_flux", pa.list_(pa.float64())),
+            pa.field("rp_flux_error", pa.list_(pa.float64())),
+            pa.field("rp_flux_over_error", pa.list_(pa.float32())),
+            pa.field("rp_mag", pa.list_(pa.float64())),
+            pa.field("photometry_flag_noisy_data", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_sm_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af1_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af2_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af3_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af4_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af5_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af6_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af7_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af8_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af9_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_bp_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_rp_unavailable", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_sm_reject", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af1_reject", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af2_reject", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af3_reject", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af4_reject", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af5_reject", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af6_reject", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af7_reject", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af8_reject", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_af9_reject", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_bp_reject", pa.list_(pa.bool_())),
+            pa.field("photometry_flag_rp_reject", pa.list_(pa.bool_())),
+            pa.field("variability_flag_g_reject", pa.list_(pa.bool_())),
+            pa.field("variability_flag_bp_reject", pa.list_(pa.bool_())),
+            pa.field("variability_flag_rp_reject", pa.list_(pa.bool_())),
+            pa.field("Norder", pa.uint8()),
+            pa.field("Dir", pa.uint64()),
+            pa.field("Npix", pa.uint64()),
+            pa.field("_hipscat_index", pa.uint64()),
+        ]
+    )
+
+    # In-memory schema uses list<item> naming convention, but pyarrow converts to
+    # the parquet-compliant list<element> convention when writing to disk.
+    # Round trip the schema to get a schema with compliant nested naming convention.
+    schema_path = os.path.join(tmp_path, "temp_schema.parquet")
+    pq.write_table(expected_parquet_schema.empty_table(), where=schema_path)
+    expected_parquet_schema = pq.read_metadata(schema_path).schema.to_arrow_schema()
+
+    schema = pq.read_metadata(output_file).schema.to_arrow_schema()
+    assert schema.equals(expected_parquet_schema, check_metadata=False)
+    schema = pq.read_metadata(os.path.join(args.catalog_path, "_metadata")).schema.to_arrow_schema()
+    assert schema.equals(expected_parquet_schema, check_metadata=False)
+    schema = pq.read_metadata(os.path.join(args.catalog_path, "_common_metadata")).schema.to_arrow_schema()
+    assert schema.equals(expected_parquet_schema, check_metadata=False)
+    schema = pds.dataset(args.catalog_path, format="parquet").schema
+    assert schema.equals(expected_parquet_schema, check_metadata=False)
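
The schema round trip in the last block of the test can be demonstrated in isolation. This standalone sketch (not part of the commit, and assuming a recent pyarrow where compliant nested types are the write default) shows the nested list field being renamed from the in-memory item convention to the parquet-compliant element convention:

import pyarrow as pa
import pyarrow.parquet as pq

in_memory = pa.schema([pa.field("transit_id", pa.list_(pa.int64()))])
print(in_memory.field("transit_id").type)  # list<item: int64>

# Writing an empty table and reading the footer back yields the compliant
# nested naming, making the schema directly comparable to files on disk.
pq.write_table(in_memory.empty_table(), "temp_schema.parquet")
round_tripped = pq.read_metadata("temp_schema.parquet").schema.to_arrow_schema()
print(round_tripped.field("transit_id").type)  # list<element: int64>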
(The diff for the third changed file, 383 added lines, presumably the gaia_epoch.ecsv test fixture referenced above, is not rendered on this page.)