Skip to content

Commit

Permalink
Round trip tests & various fixes (#42)
Browse files Browse the repository at this point in the history
Round trip tests for pyarrow conversion.
  • Loading branch information
kylebarron authored Apr 22, 2024
1 parent 49f678f commit fd7c9a4
Show file tree
Hide file tree
Showing 17 changed files with 24,273 additions and 49 deletions.
52 changes: 39 additions & 13 deletions stac_geoparquet/from_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,46 @@ def _convert_bbox_to_array(table: pa.Table) -> pa.Table:
new_chunks = []
for chunk in bbox_col.chunks:
assert pa.types.is_struct(chunk.type)
xmin = chunk.field(0).to_numpy()
ymin = chunk.field(1).to_numpy()
xmax = chunk.field(2).to_numpy()
ymax = chunk.field(3).to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
xmax,
ymax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4)
if bbox_col.type.num_fields == 4:
xmin = chunk.field("xmin").to_numpy()
ymin = chunk.field("ymin").to_numpy()
xmax = chunk.field("xmax").to_numpy()
ymax = chunk.field("ymax").to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
xmax,
ymax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 4)

elif bbox_col.type.num_fields == 6:
xmin = chunk.field("xmin").to_numpy()
ymin = chunk.field("ymin").to_numpy()
zmin = chunk.field("zmin").to_numpy()
xmax = chunk.field("xmax").to_numpy()
ymax = chunk.field("ymax").to_numpy()
zmax = chunk.field("zmax").to_numpy()
coords = np.column_stack(
[
xmin,
ymin,
zmin,
xmax,
ymax,
zmax,
]
)

list_arr = pa.FixedSizeListArray.from_arrays(coords.flatten("C"), 6)

else:
raise ValueError("Expected 4 or 6 fields in bbox struct.")

new_chunks.append(list_arr)

return table.set_column(bbox_col_idx, "bbox", new_chunks)
149 changes: 113 additions & 36 deletions stac_geoparquet/to_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def parse_stac_items_to_arrow(
*,
chunk_size: int = 8192,
schema: Optional[pa.Schema] = None,
downcast: bool = True,
) -> pa.Table:
"""Parse a collection of STAC Items to a :class:`pyarrow.Table`.
Expand All @@ -41,6 +42,7 @@ def parse_stac_items_to_arrow(
schema: The schema of the input data. If provided, can improve memory use;
otherwise all items need to be parsed into a single array for schema
inference. Defaults to None.
downcast: if True, store bbox as float32 for memory and disk saving.
Returns:
a pyarrow Table with the STAC-GeoParquet representation of items.
Expand All @@ -53,22 +55,23 @@ def parse_stac_items_to_arrow(
for chunk in _chunks(items, chunk_size):
batches.append(_stac_items_to_arrow(chunk, schema=schema))

stac_table = pa.Table.from_batches(batches, schema=schema)
table = pa.Table.from_batches(batches, schema=schema)
else:
# If schema is _not_ provided, then we must convert to Arrow all at once, or
# else it would be possible for a STAC item late in the collection (after the
# first chunk) to have a different schema and not match the schema inferred for
# the first chunk.
stac_table = pa.Table.from_batches([_stac_items_to_arrow(items)])
table = pa.Table.from_batches([_stac_items_to_arrow(items)])

return _process_arrow_table(stac_table)
return _process_arrow_table(table, downcast=downcast)


def parse_stac_ndjson_to_arrow(
path: Union[str, Path],
*,
chunk_size: int = 8192,
schema: Optional[pa.Schema] = None,
downcast: bool = True,
) -> pa.Table:
# Define outside of if/else to make mypy happy
items: List[dict] = []
Expand Down Expand Up @@ -98,14 +101,14 @@ def parse_stac_ndjson_to_arrow(
if len(items) > 0:
batches.append(_stac_items_to_arrow(items, schema=schema))

stac_table = pa.Table.from_batches(batches, schema=schema)
return _process_arrow_table(stac_table)
table = pa.Table.from_batches(batches, schema=schema)
return _process_arrow_table(table, downcast=downcast)


def _process_arrow_table(table: pa.Table) -> pa.Table:
def _process_arrow_table(table: pa.Table, *, downcast: bool = True) -> pa.Table:
    """Run the standard post-conversion pipeline over a STAC Arrow table.

    Applies, in order: ``_bring_properties_to_top_level``,
    ``_convert_timestamp_columns``, and ``_convert_bbox_to_struct``.

    Args:
        table: the Arrow table produced from STAC items.
        downcast: forwarded to ``_convert_bbox_to_struct``; if True, bbox
            values are stored as float32.

    Returns:
        The transformed pyarrow Table.
    """
    pipeline = (
        _bring_properties_to_top_level,
        _convert_timestamp_columns,
        lambda t: _convert_bbox_to_struct(t, downcast=downcast),
    )
    for step in pipeline:
        table = step(table)
    return table


Expand Down Expand Up @@ -192,11 +195,21 @@ def _convert_timestamp_columns(table: pa.Table) -> pa.Table:
except KeyError:
continue

field_index = table.schema.get_field_index(column_name)

if pa.types.is_timestamp(column.type):
continue

# STAC allows datetimes to be null. If all rows are null, the column type may be
# inferred as null. We cast this to a timestamp column.
elif pa.types.is_null(column.type):
table = table.set_column(
field_index, column_name, column.cast(pa.timestamp("us"))
)

elif pa.types.is_string(column.type):
table = table.drop(column_name).append_column(
column_name, _convert_timestamp_column(column)
table = table.set_column(
field_index, column_name, _convert_timestamp_column(column)
)
else:
raise ValueError(
Expand Down Expand Up @@ -224,7 +237,26 @@ def _convert_timestamp_column(column: pa.ChunkedArray) -> pa.ChunkedArray:
return pa.chunked_array(chunks)


def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool = True) -> pa.Table:
def is_bbox_3d(bbox_col: "pa.ChunkedArray") -> bool:
    """Infer whether the bounding box column represents 2d or 3d bounding boxes.

    Each element of the (list-typed) column is a flat bbox; a per-element
    length of 4 means 2d and 6 means 3d.

    Args:
        bbox_col: chunked list array of bounding boxes.

    Returns:
        True for 3d (6 coordinates per bbox), False for 2d (4 coordinates).

    Raises:
        ValueError: if the column is empty (dimensionality cannot be
            inferred), mixes 2d and 3d boxes, or has an unexpected
            per-element length.
    """
    lengths: set = set()
    for chunk in bbox_col.chunks:
        offsets = chunk.offsets.to_numpy()
        # Per-element lengths are the differences between consecutive offsets.
        lengths.update(np.unique(offsets[1:] - offsets[:-1]))

    if len(lengths) > 1:
        raise ValueError("Mixed 2d-3d bounding boxes not yet supported")

    if not lengths:
        # No chunks (or only empty chunks): the original code would raise a
        # bare IndexError here; fail with a descriptive error instead.
        raise ValueError("Cannot infer bbox dimensionality from an empty column")

    offset = lengths.pop()
    if offset == 6:
        return True
    elif offset == 4:
        return False
    else:
        raise ValueError(f"Unexpected bbox offset: {offset=}")


def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool) -> pa.Table:
"""Convert bbox column to a struct representation
Since the bbox in JSON is stored as an array, pyarrow automatically converts the
Expand All @@ -244,6 +276,7 @@ def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool = True) -> pa.Tab
"""
bbox_col_idx = table.schema.get_field_index("bbox")
bbox_col = table.column(bbox_col_idx)
bbox_3d = is_bbox_3d(bbox_col)

new_chunks = []
for chunk in bbox_col.chunks:
Expand All @@ -252,36 +285,80 @@ def _convert_bbox_to_struct(table: pa.Table, *, downcast: bool = True) -> pa.Tab
or pa.types.is_large_list(chunk.type)
or pa.types.is_fixed_size_list(chunk.type)
)
coords = chunk.flatten().to_numpy().reshape(-1, 4)
xmin = coords[:, 0]
ymin = coords[:, 1]
xmax = coords[:, 2]
ymax = coords[:, 3]
if bbox_3d:
coords = chunk.flatten().to_numpy().reshape(-1, 6)
else:
coords = chunk.flatten().to_numpy().reshape(-1, 4)

if downcast:
coords = coords.astype(np.float32)

# Round min values down to the next float32 value
# Round max values up to the next float32 value
xmin = np.nextafter(xmin, -np.Infinity)
ymin = np.nextafter(ymin, -np.Infinity)
xmax = np.nextafter(xmax, np.Infinity)
ymax = np.nextafter(ymax, np.Infinity)

struct_arr = pa.StructArray.from_arrays(
[
xmin,
ymin,
xmax,
ymax,
],
names=[
"xmin",
"ymin",
"xmax",
"ymax",
],
)
if bbox_3d:
xmin = coords[:, 0]
ymin = coords[:, 1]
zmin = coords[:, 2]
xmax = coords[:, 3]
ymax = coords[:, 4]
zmax = coords[:, 5]

if downcast:
# Round min values down to the next float32 value
# Round max values up to the next float32 value
xmin = np.nextafter(xmin, -np.Infinity)
ymin = np.nextafter(ymin, -np.Infinity)
zmin = np.nextafter(zmin, -np.Infinity)
xmax = np.nextafter(xmax, np.Infinity)
ymax = np.nextafter(ymax, np.Infinity)
zmax = np.nextafter(zmax, np.Infinity)

struct_arr = pa.StructArray.from_arrays(
[
xmin,
ymin,
zmin,
xmax,
ymax,
zmax,
],
names=[
"xmin",
"ymin",
"zmin",
"xmax",
"ymax",
"zmax",
],
)

else:
xmin = coords[:, 0]
ymin = coords[:, 1]
xmax = coords[:, 2]
ymax = coords[:, 3]

if downcast:
# Round min values down to the next float32 value
# Round max values up to the next float32 value
xmin = np.nextafter(xmin, -np.Infinity)
ymin = np.nextafter(ymin, -np.Infinity)
xmax = np.nextafter(xmax, np.Infinity)
ymax = np.nextafter(ymax, np.Infinity)

struct_arr = pa.StructArray.from_arrays(
[
xmin,
ymin,
xmax,
ymax,
],
names=[
"xmin",
"ymin",
"xmax",
"ymax",
],
)

new_chunks.append(struct_arr)

return table.set_column(bbox_col_idx, "bbox", new_chunks)
Loading

0 comments on commit fd7c9a4

Please sign in to comment.