Add batched tests (#73)
kylebarron authored Jun 25, 2024
1 parent 5af1294 commit fd51c9c
Showing 2 changed files with 32 additions and 10 deletions.
4 changes: 4 additions & 0 deletions stac_geoparquet/arrow/_api.py
@@ -146,6 +146,10 @@ def stac_table_to_ndjson(
) -> None:
"""Write STAC Arrow to a newline-delimited JSON file.
!!! note
This function _appends_ to the JSON file at `dest`; it does not overwrite any
existing data.
Args:
table: STAC in Arrow form. This can be a pyarrow Table, a pyarrow
RecordBatchReader, or any other Arrow stream object exposed through the
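The docstring note added above states that stac_table_to_ndjson appends to the file at `dest` rather than overwriting it. The following is a minimal sketch of that behavior, not taken from this commit; the input path "items.json" and output path "out.ndjson" are hypothetical, and the final assertion assumes each item is written as exactly one line.

# Sketch only (not part of this commit): file paths below are hypothetical.
import json

from stac_geoparquet.arrow import parse_stac_items_to_arrow, stac_table_to_ndjson

with open("items.json") as f:
    items = json.load(f)

table = parse_stac_items_to_arrow(items).read_all()

# Two writes to the same destination: the second call appends rather than
# truncating, so the file should end up with every item twice.
stac_table_to_ndjson(table, "out.ndjson")
stac_table_to_ndjson(table, "out.ndjson")

with open("out.ndjson") as f:
    assert sum(1 for _ in f) == 2 * len(table)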
38 changes: 28 additions & 10 deletions tests/test_arrow.py
@@ -1,3 +1,4 @@
import itertools
import json
from io import BytesIO
from pathlib import Path
@@ -7,6 +8,7 @@
import pytest

from stac_geoparquet.arrow import (
DEFAULT_JSON_CHUNK_SIZE,
parse_stac_items_to_arrow,
parse_stac_ndjson_to_arrow,
stac_table_to_items,
@@ -33,38 +35,54 @@
"us-census",
]

CHUNK_SIZES = [2, DEFAULT_JSON_CHUNK_SIZE]

@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS)
def test_round_trip_read_write(collection_id: str):

@pytest.mark.parametrize(
"collection_id,chunk_size", itertools.product(TEST_COLLECTIONS, CHUNK_SIZES)
)
def test_round_trip_read_write(collection_id: str, chunk_size: int):
with open(HERE / "data" / f"{collection_id}-pc.json") as f:
items = json.load(f)

table = pa.Table.from_batches(parse_stac_items_to_arrow(items))
table = parse_stac_items_to_arrow(items, chunk_size=chunk_size).read_all()
items_result = list(stac_table_to_items(table))

for result, expected in zip(items_result, items):
assert_json_value_equal(result, expected, precision=0)


@pytest.mark.parametrize("collection_id", TEST_COLLECTIONS)
def test_round_trip_write_read_ndjson(collection_id: str, tmp_path: Path):
@pytest.mark.parametrize(
"collection_id,chunk_size", itertools.product(TEST_COLLECTIONS, CHUNK_SIZES)
)
def test_round_trip_write_read_ndjson(
collection_id: str, chunk_size: int, tmp_path: Path
):
# First load into a STAC-GeoParquet table
path = HERE / "data" / f"{collection_id}-pc.json"
table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(path))
table = parse_stac_ndjson_to_arrow(path, chunk_size=chunk_size).read_all()

# Then write to disk
stac_table_to_ndjson(table, tmp_path / "tmp.ndjson")

# Then read back and assert tables match
table = pa.Table.from_batches(parse_stac_ndjson_to_arrow(tmp_path / "tmp.ndjson"))
with open(path) as f:
orig_json = json.load(f)

rt_json = []
with open(tmp_path / "tmp.ndjson") as f:
for line in f:
rt_json.append(json.loads(line))

# Then read back and assert JSON data matches
assert_json_value_equal(orig_json, rt_json, precision=0)


def test_table_contains_geoarrow_metadata():
collection_id = "naip"
with open(HERE / "data" / f"{collection_id}-pc.json") as f:
items = json.load(f)

table = pa.Table.from_batches(parse_stac_items_to_arrow(items))
table = parse_stac_items_to_arrow(items).read_all()
field_meta = table.schema.field("geometry").metadata
assert field_meta[b"ARROW:extension:name"] == b"geoarrow.wkb"
assert json.loads(field_meta[b"ARROW:extension:metadata"])["crs"]["id"] == {
@@ -107,7 +125,7 @@ def test_to_parquet_two_geometry_columns():
with open(HERE / "data" / "3dep-lidar-copc-pc.json") as f:
items = json.load(f)

table = pa.Table.from_batches(parse_stac_items_to_arrow(items))
table = parse_stac_items_to_arrow(items).read_all()
with BytesIO() as bio:
to_parquet(table, bio)
bio.seek(0)
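The updated tests parameterize over chunk sizes and call .read_all() on the result of the parse functions, which suggests those functions return a pyarrow RecordBatchReader-style stream whose batches each cover at most chunk_size items. A rough usage sketch under that assumption (the path "items.ndjson" is hypothetical):

from stac_geoparquet.arrow import DEFAULT_JSON_CHUNK_SIZE, parse_stac_ndjson_to_arrow

# Stream record batches, each covering at most chunk_size items.
for batch in parse_stac_ndjson_to_arrow("items.ndjson", chunk_size=2):
    ...  # process each pyarrow.RecordBatch incrementally

# Or materialize everything into a single table, as the updated tests do.
table = parse_stac_ndjson_to_arrow(
    "items.ndjson", chunk_size=DEFAULT_JSON_CHUNK_SIZE
).read_all()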
