Support reading a file to pyarrow table in catalog import (#415)
delucchi-cmu authored Oct 21, 2024
1 parent aeb92ae commit bdee662
Showing 30 changed files with 271 additions and 94 deletions.
30 changes: 29 additions & 1 deletion src/hats_import/catalog/file_readers.py
@@ -3,7 +3,7 @@
import abc

import pandas as pd
import pyarrow
import pyarrow as pa
import pyarrow.dataset
import pyarrow.parquet as pq
from astropy.io import ascii as ascii_reader
@@ -356,6 +356,34 @@ def read(self, input_file, read_columns=None):
yield smaller_table.to_pandas()


class ParquetPyarrowReader(InputReader):
"""Parquet reader for the most common Parquet reading arguments.
Attributes:
chunksize (int): number of rows of the file to process at once.
For large files, this can prevent loading the entire file
into memory at once.
column_names (list[str] or None): Names of columns to use from the input dataset.
If None, use all columns.
kwargs: arguments to pass along to pyarrow.parquet.ParquetFile.
See https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetFile.html
"""

def __init__(self, chunksize=500_000, column_names=None, **kwargs):
self.chunksize = chunksize
self.column_names = column_names
self.kwargs = kwargs

def read(self, input_file, read_columns=None):
self.regular_file_exists(input_file, **self.kwargs)
columns = read_columns or self.column_names
parquet_file = pq.ParquetFile(input_file, **self.kwargs)
for smaller_table in parquet_file.iter_batches(batch_size=self.chunksize, columns=columns):
table = pa.Table.from_batches([smaller_table])
table = table.replace_schema_metadata()
yield table


class IndexedParquetReader(InputReader):
"""Reads an index file, containing paths to parquet files to be read and batched
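As a quick usage sketch of the new reader (hypothetical file name and chunk size; not part of the commit), each iteration yields a pyarrow Table of at most `chunksize` rows, with any pandas schema metadata stripped:

```python
from hats_import.catalog.file_readers import ParquetPyarrowReader

# Hypothetical local parquet file; column_names restricts which columns are read.
reader = ParquetPyarrowReader(chunksize=100_000, column_names=["id", "ra", "dec"])
for chunk in reader.read("catalog_part_0.parquet"):
    # Each chunk is a pyarrow.Table, not a pandas DataFrame.
    print(chunk.num_rows, chunk.schema.names)
```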
115 changes: 53 additions & 62 deletions src/hats_import/catalog/map_reduce.py
@@ -4,6 +4,7 @@

import hats.pixel_math.healpix_shim as hp
import numpy as np
import pandas as pd
import pyarrow as pa
import pyarrow.parquet as pq
from hats import pixel_math
@@ -50,21 +51,30 @@ def _iterate_input_file(

for chunk_number, data in enumerate(file_reader.read(input_file, read_columns=read_columns)):
if use_healpix_29:
if data.index.name == SPATIAL_INDEX_COLUMN:
if isinstance(data, pd.DataFrame) and data.index.name == SPATIAL_INDEX_COLUMN:
mapped_pixels = spatial_index_to_healpix(data.index, target_order=highest_order)
else:
mapped_pixels = spatial_index_to_healpix(
data[SPATIAL_INDEX_COLUMN], target_order=highest_order
)
else:
# Set up the pixel data
mapped_pixels = hp.ang2pix(
2**highest_order,
data[ra_column].to_numpy(copy=False, dtype=float),
data[dec_column].to_numpy(copy=False, dtype=float),
lonlat=True,
nest=True,
)
if isinstance(data, pd.DataFrame):
mapped_pixels = hp.ang2pix(
2**highest_order,
data[ra_column].to_numpy(copy=False, dtype=float),
data[dec_column].to_numpy(copy=False, dtype=float),
lonlat=True,
nest=True,
)
else:
mapped_pixels = hp.ang2pix(
2**highest_order,
data[ra_column].to_numpy(),
data[dec_column].to_numpy(),
lonlat=True,
nest=True,
)
yield chunk_number, data, mapped_pixels
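The two ang2pix branches above differ only in how the coordinate columns are converted to numpy: pandas Series.to_numpy accepts copy and dtype arguments, while pyarrow's ChunkedArray.to_numpy does not, which is presumably why the pyarrow path calls it bare. A minimal illustration (toy values):

```python
import pandas as pd
import pyarrow as pa

ra_series = pd.Series([315.0, 320.5])
ra_chunked = pa.chunked_array([[315.0, 320.5]])

ra_from_pandas = ra_series.to_numpy(copy=False, dtype=float)  # pandas path
ra_from_arrow = ra_chunked.to_numpy()                         # pyarrow path: no copy/dtype kwargs
```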


@@ -168,17 +178,20 @@ def split_pixels(
unique_pixels, unique_inverse = np.unique(aligned_pixels, return_inverse=True)

for unique_index, [order, pixel, _] in enumerate(unique_pixels):
filtered_data = data.iloc[unique_inverse == unique_index]

pixel_dir = get_pixel_cache_directory(cache_shard_path, HealpixPixel(order, pixel))
file_io.make_directory(pixel_dir, exist_ok=True)
output_file = file_io.append_paths_to_pointer(
pixel_dir, f"shard_{splitting_key}_{chunk_number}.parquet"
)
if _has_named_index(filtered_data):
filtered_data.to_parquet(output_file.path, index=True, filesystem=output_file.fs)
if isinstance(data, pd.DataFrame):
filtered_data = data.iloc[unique_inverse == unique_index]
if _has_named_index(filtered_data):
filtered_data = filtered_data.reset_index()
filtered_data = pa.Table.from_pandas(filtered_data, preserve_index=False)
else:
filtered_data.to_parquet(output_file.path, index=False, filesystem=output_file.fs)
filtered_data = data.filter(unique_inverse == unique_index)

pq.write_table(filtered_data, output_file.path, filesystem=output_file.fs)
del filtered_data

ResumePlan.splitting_key_done(tmp_path=resume_path, splitting_key=splitting_key)
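On the pyarrow path, the per-pixel shard is selected with Table.filter and a plain boolean mask, the analogue of DataFrame.iloc with a mask on the pandas path. A minimal sketch (toy table and column names, not the import pipeline):

```python
import numpy as np
import pyarrow as pa

table = pa.table({"id": [700, 701, 702], "pixel": [44, 45, 44]})
unique_pixels, unique_inverse = np.unique(table["pixel"].to_numpy(), return_inverse=True)

# Rows that map to the first unique pixel; the mask is a numpy boolean array.
shard = table.filter(unique_inverse == 0)
```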
Expand Down Expand Up @@ -258,15 +271,10 @@ def reduce_pixel_shards(
if use_schema_file:
schema = file_io.read_parquet_metadata(use_schema_file).schema.to_arrow_schema()

tables = []
healpix_pixel = HealpixPixel(destination_pixel_order, destination_pixel_number)
pixel_dir = get_pixel_cache_directory(cache_shard_path, healpix_pixel)

if schema:
tables.append(pq.read_table(pixel_dir, schema=schema))
else:
tables.append(pq.read_table(pixel_dir))

merged_table = pa.concat_tables(tables)
merged_table = pq.read_table(pixel_dir, schema=schema)

rows_written = len(merged_table)

@@ -277,38 +285,36 @@
f" Expected {destination_pixel_size}, wrote {rows_written}"
)

dataframe = merged_table.to_pandas()
if sort_columns:
dataframe = dataframe.sort_values(sort_columns.split(","), kind="stable")
split_columns = sort_columns.split(",")
if len(split_columns) > 1:
merged_table = merged_table.sort_by([(col_name, "ascending") for col_name in split_columns])
else:
merged_table = merged_table.sort_by(sort_columns)
if add_healpix_29:
## If we had a meaningful index before, preserve it as a column.
if _has_named_index(dataframe):
dataframe = dataframe.reset_index()

dataframe[SPATIAL_INDEX_COLUMN] = pixel_math.compute_spatial_index(
dataframe[ra_column].values,
dataframe[dec_column].values,
)
dataframe = dataframe.set_index(SPATIAL_INDEX_COLUMN).sort_index(kind="stable")

# Adjust the schema to make sure that the _healpix_29 will
# be saved as a uint64
merged_table = merged_table.add_column(
0,
SPATIAL_INDEX_COLUMN,
[
pixel_math.compute_spatial_index(
merged_table[ra_column].to_numpy(),
merged_table[dec_column].to_numpy(),
)
],
).sort_by(SPATIAL_INDEX_COLUMN)
elif use_healpix_29:
if dataframe.index.name != SPATIAL_INDEX_COLUMN:
dataframe = dataframe.set_index(SPATIAL_INDEX_COLUMN)
dataframe = dataframe.sort_index(kind="stable")
merged_table = merged_table.sort_by(SPATIAL_INDEX_COLUMN)

dataframe["Norder"] = np.full(rows_written, fill_value=healpix_pixel.order, dtype=np.uint8)
dataframe["Dir"] = np.full(rows_written, fill_value=healpix_pixel.dir, dtype=np.uint64)
dataframe["Npix"] = np.full(rows_written, fill_value=healpix_pixel.pixel, dtype=np.uint64)

if schema:
schema = _modify_arrow_schema(schema, add_healpix_29)
dataframe.to_parquet(destination_file.path, schema=schema, filesystem=destination_file.fs)
else:
dataframe.to_parquet(destination_file.path, filesystem=destination_file.fs)
merged_table = (
merged_table.append_column(
"Norder", [np.full(rows_written, fill_value=healpix_pixel.order, dtype=np.uint8)]
)
.append_column("Dir", [np.full(rows_written, fill_value=healpix_pixel.dir, dtype=np.uint64)])
.append_column("Npix", [np.full(rows_written, fill_value=healpix_pixel.pixel, dtype=np.uint64)])
)

del dataframe, merged_table, tables
pq.write_table(merged_table, destination_file.path, filesystem=destination_file.fs)
del merged_table

if delete_input_files:
pixel_dir = get_pixel_cache_directory(cache_shard_path, healpix_pixel)
@@ -322,18 +328,3 @@ def reduce_pixel_shards(
exception,
)
raise exception


def _modify_arrow_schema(schema, add_healpix_29):
if add_healpix_29:
pandas_index_column = schema.get_field_index("__index_level_0__")
if pandas_index_column != -1:
schema = schema.remove(pandas_index_column)
schema = schema.insert(0, pa.field(SPATIAL_INDEX_COLUMN, pa.int64()))
schema = (
schema.append(pa.field("Norder", pa.uint8()))
.append(pa.field("Dir", pa.uint64()))
.append(pa.field("Npix", pa.uint64()))
)

return schema
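The reduce step now stays in pyarrow end to end: add_column inserts the spatial index at position 0, append_column adds Norder/Dir/Npix with explicit dtypes, and sort_by handles both single-key and multi-key sorts. A standalone sketch of those three calls (toy values, not the pipeline itself):

```python
import numpy as np
import pyarrow as pa

table = pa.table({"ra": [315.0, 320.5], "dec": [-60.0, -55.2]})

# Insert a column at position 0; the numpy array is wrapped in a list of chunks.
table = table.add_column(0, "_healpix_29", [np.array([97, 12], dtype=np.int64)])

# Append a metadata column with a fixed dtype, then sort.
table = table.append_column("Norder", [np.full(len(table), 1, dtype=np.uint8)])
table = table.sort_by("_healpix_29")                                 # single key
table = table.sort_by([("ra", "ascending"), ("dec", "ascending")])   # multiple keys
```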
7 changes: 5 additions & 2 deletions src/hats_import/index/map_reduce.py
@@ -17,8 +17,9 @@ def read_leaf_file(input_file, include_columns, include_healpix_29, drop_duplica
schema=schema,
)

data = data.reset_index()
if not include_healpix_29:
if data.index.name == SPATIAL_INDEX_COLUMN:
data = data.reset_index()
if not include_healpix_29 and SPATIAL_INDEX_COLUMN in data.columns:
data = data.drop(columns=[SPATIAL_INDEX_COLUMN])

if drop_duplicates:
@@ -32,6 +33,8 @@ def create_index(args, client):
include_columns = [args.indexing_column]
if args.extra_columns:
include_columns.extend(args.extra_columns)
if args.include_healpix_29:
include_columns.append(SPATIAL_INDEX_COLUMN)
if args.include_order_pixel:
include_columns.extend(["Norder", "Dir", "Npix"])

Binary file modified tests/data/small_sky_object_catalog/dataset/_common_metadata
Binary file modified tests/data/small_sky_object_catalog/dataset/_metadata
6 changes: 3 additions & 3 deletions tests/data/small_sky_object_catalog/properties
@@ -7,8 +7,8 @@ hats_col_dec=dec
hats_max_rows=1000000
hats_order=0
moc_sky_fraction=0.08333
hats_builder=hats-import v0.3.6.dev26+g40366b4
hats_creation_date=2024-10-11T15\:02UTC
hats_estsize=74
hats_builder=hats-import v0.4.1.dev2+gaeb92ae
hats_creation_date=2024-10-21T13\:22UTC
hats_estsize=70
hats_release_date=2024-09-18
hats_version=v0.1
14 changed binary files not shown.
Binary file modified tests/data/small_sky_source_catalog/dataset/_common_metadata
Binary file modified tests/data/small_sky_source_catalog/dataset/_metadata
6 changes: 3 additions & 3 deletions tests/data/small_sky_source_catalog/properties
@@ -7,8 +7,8 @@ hats_col_dec=source_dec
hats_max_rows=3000
hats_order=2
moc_sky_fraction=0.16667
hats_builder=hats-import v0.3.6.dev26+g40366b4
hats_creation_date=2024-10-11T15\:02UTC
hats_estsize=1105
hats_builder=hats-import v0.4.1.dev2+gaeb92ae
hats_creation_date=2024-10-21T13\:22UTC
hats_estsize=1083
hats_release_date=2024-09-18
hats_version=v0.1
3 changes: 1 addition & 2 deletions tests/hats_import/catalog/test_map_reduce.py
@@ -353,10 +353,9 @@ def test_reduce_healpix_29(parquet_shards_dir, assert_parquet_file_ids, tmp_path
expected_ids = [*range(700, 831)]
assert_parquet_file_ids(output_file, "id", expected_ids)
data_frame = pd.read_parquet(output_file, engine="pyarrow")
assert data_frame.index.name == "_healpix_29"
npt.assert_array_equal(
data_frame.columns,
["id", "ra", "dec", "ra_error", "dec_error", "Norder", "Dir", "Npix"],
["_healpix_29", "id", "ra", "dec", "ra_error", "dec_error", "Norder", "Dir", "Npix"],
)

mr.reduce_pixel_shards(
3 changes: 2 additions & 1 deletion tests/hats_import/catalog/test_run_import.py
@@ -278,6 +278,7 @@ def test_dask_runner(
# Check that the schema is correct for leaf parquet and _metadata files
expected_parquet_schema = pa.schema(
[
pa.field("_healpix_29", pa.int64()),
pa.field("id", pa.int64()),
pa.field("ra", pa.float32()),
pa.field("dec", pa.float32()),
@@ -286,7 +287,6 @@
pa.field("Norder", pa.uint8()),
pa.field("Dir", pa.uint64()),
pa.field("Npix", pa.uint64()),
pa.field("_healpix_29", pa.int64()),
]
)
schema = pq.read_metadata(output_file).schema.to_arrow_schema()
@@ -298,6 +298,7 @@
data_frame = pd.read_parquet(output_file, engine="pyarrow")
expected_dtypes = pd.Series(
{
"_healpix_29": np.int64,
"id": np.int64,
"ra": np.float32,
"dec": np.float32,