Commit a0cc958

Fix indexing so dimensions are never squeezed so they are included in response (#59)

* Fix indexing so dimensions are never squeezed so they are persisted in response

* lint

* update tests
mpiannucci authored Nov 14, 2024
1 parent bf9d5d9 commit a0cc958
Showing 6 changed files with 39 additions and 15 deletions.
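
The fix turns on a basic xarray behavior: selecting with a scalar label squeezes the indexed dimension out of the result, while selecting with a one-element list keeps it as a size-1 dimension that can survive into the EDR response. A minimal standalone sketch of that behavior (the dataset and values are illustrative, not the plugin's test fixture):

import numpy as np
import xarray as xr

# Toy dataset shaped like the test data: (time, lat, lon)
ds = xr.Dataset(
    {"air": (("time", "lat", "lon"), np.zeros((4, 25, 53)))},
    coords={
        "time": np.arange(4),
        "lat": np.linspace(15.0, 75.0, 25),
        "lon": np.linspace(200.0, 330.0, 53),
    },
)

# Scalar labels squeeze the selected dimensions away
scalar = ds.sel(lat=45.0, lon=205.0, method="nearest")
print(scalar["air"].dims)   # ('time',)

# One-element lists keep them as size-1 dimensions
kept = ds.sel(lat=[45.0], lon=[205.0], method="nearest")
print(kept["air"].dims)     # ('time', 'lat', 'lon')
print(kept["air"].shape)    # (4, 1, 1)
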
31 changes: 28 additions & 3 deletions tests/test_cf_router.py
@@ -93,9 +93,8 @@ def test_cf_position_query(cf_client, cf_dataset):

    assert air_range["type"] == "NdArray", "Response range should be a NdArray"
    assert air_range["dataType"] == "float", "Air dataType should be floats"
-    assert air_range["axisNames"] == ["t"], "Time should be the only remaining axes"
-    assert len(air_range["shape"]) == 1, "There should only one axes"
-    assert air_range["shape"][0] == len(axes["t"]["values"]), "The shape of the "
+    assert air_range["axisNames"] == ["t", "y", "x"], "All dimensions should persist"
+    assert air_range["shape"] == [4, 1, 1], "The shape of the array should be 4x1x1"
    assert (
        len(air_range["values"]) == 4
    ), "There should be 4 values, one for each time step"
@@ -129,6 +128,32 @@ def test_cf_position_csv(cf_client):
    for key in ("time", "lat", "lon", "air", "cell_area"):
        assert key in csv_data[0], f"column {key} should be in the header"

+    # single time step test
+    response = cf_client.get(
+        f"/datasets/air/edr/position?coords=POINT({x} {y})&f=csv&parameter-name=air&datetime=2013-01-01T00:00:00", # noqa
+    )
+
+    assert response.status_code == 200, "Response did not return successfully"
+    assert (
+        "text/csv" in response.headers["content-type"]
+    ), "The content type should be set as a CSV"
+    assert (
+        "attachment" in response.headers["content-disposition"]
+    ), "The response should be set as an attachment to trigger download"
+    assert (
+        "position.csv" in response.headers["content-disposition"]
+    ), "The file name should be position.csv"
+
+    csv_data = [
+        line.split(",") for line in response.content.decode("utf-8").splitlines()
+    ]
+
+    assert (
+        len(csv_data) == 2
+    ), "There should be 2 data rows, one data and one header row"
+    for key in ("time", "lat", "lon", "air", "cell_area"):
+        assert key in csv_data[0], f"column {key} should be in the header"
+

def test_cf_position_csv_interpolate(cf_client):
    x = 204
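
With all three dimensions preserved, the response range from the position query now carries explicit t/y/x axes. This is roughly the structure the updated test asserts against (the numeric values below are illustrative placeholders, not taken from the fixture):

air_range = {
    "type": "NdArray",
    "dataType": "float",
    "axisNames": ["t", "y", "x"],
    "shape": [4, 1, 1],
    "values": [280.2, 279.6, 279.1, 278.5],  # one value per time step; numbers are made up
}
assert len(air_range["values"]) == air_range["shape"][0]
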
6 changes: 3 additions & 3 deletions tests/test_select.py
@@ -36,7 +36,7 @@ def test_select_query(regular_xy_dataset):
    assert ds["time"] == pd.to_datetime(
        "2013-01-01T06:00:00",
    ), "Dataset shape is incorrect"
-    assert ds["air"].shape == (25, 53), "Dataset shape is incorrect"
+    assert ds["air"].shape == (1, 25, 53), "Dataset shape is incorrect"

    query = EDRQuery(
        coords="POINT(200 45)",
@@ -127,7 +127,7 @@ def test_select_position_regular_xy(regular_xy_dataset):
    assert "lat" in ds, "Dataset does not contain the lat variable"
    assert "lon" in ds, "Dataset does not contain the lon variable"

-    assert ds["air"].shape == ds["time"].shape, "Dataset shape is incorrect"
+    assert ds["air"].shape == (2920, 1, 1), "Dataset shape is incorrect"
    npt.assert_array_equal(ds["lat"], 45.0), "Latitude is incorrect"
    npt.assert_array_equal(ds["lon"], 205.0), "Longitude is incorrect"
    npt.assert_approx_equal(ds["air"][0], 280.2), "Temperature is incorrect"
@@ -143,7 +143,7 @@ def test_select_position_regular_xy_interpolate(regular_xy_dataset):
    assert "lat" in ds, "Dataset does not contain the lat variable"
    assert "lon" in ds, "Dataset does not contain the lon variable"

-    assert ds["air"].shape == ds["time"].shape, "Dataset shape is incorrect"
+    assert ds["air"].shape == (2920, 1, 1), "Dataset shape is incorrect"
    npt.assert_array_equal(ds["lat"], 44.0), "Latitude is incorrect"
    npt.assert_array_equal(ds["lon"], 204.0), "Longitude is incorrect"
    npt.assert_approx_equal(ds["air"][0], 281.376), "Temperature is incorrect"
1 change: 0 additions & 1 deletion xpublish_edr/formats/to_csv.py
@@ -8,7 +8,6 @@

def to_csv(ds: xr.Dataset):
    """Return a CSV response from an xarray dataset"""
-    ds = ds.squeeze()
    df = ds.to_dataframe()

    csv = df.to_csv()
1 change: 0 additions & 1 deletion xpublish_edr/formats/to_geojson.py
@@ -20,7 +20,6 @@ def handle_date_columns(df: pd.DataFrame) -> pd.DataFrame:

def to_geojson(ds: xr.Dataset):
    """Return a GeoJSON response from an xarray dataset"""
-    ds = ds.squeeze()
    axes = ds.cf.axes
    (x_col,) = axes["X"]
    (y_col,) = axes["Y"]
4 changes: 2 additions & 2 deletions xpublish_edr/geometry/position.py
@@ -45,9 +45,9 @@ def _select_by_position_regular_xy_grid(
    """
    # Find the nearest X and Y coordinates to the point
    if method == "nearest":
-        return ds.cf.sel(X=point.x, Y=point.y, method=method)
+        return ds.cf.sel(X=[point.x], Y=[point.y], method=method)
    else:
-        return ds.cf.interp(X=point.x, Y=point.y, method=method)
+        return ds.cf.interp(X=[point.x], Y=[point.y], method=method)


def _select_by_multiple_positions_regular_xy_grid(
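
For context, this is the same list-based selection used in `_select_by_position_regular_xy_grid` above, shown through the cf accessor in a hedged standalone sketch (the dataset, axis attributes, and query point are illustrative, not the plugin's fixture):

import cf_xarray  # noqa: F401  # registers the .cf accessor
import numpy as np
import xarray as xr
from shapely.geometry import Point

ds = xr.Dataset(
    {"air": (("lat", "lon"), np.random.rand(25, 53))},
    coords={
        "lat": ("lat", np.linspace(15.0, 75.0, 25), {"axis": "Y"}),
        "lon": ("lon", np.linspace(200.0, 330.0, 53), {"axis": "X"}),
    },
)

point = Point(205.0, 45.0)
subset = ds.cf.sel(X=[point.x], Y=[point.y], method="nearest")
print(dict(subset.sizes))  # {'lat': 1, 'lon': 1} -- both dimensions survive with length 1
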
11 changes: 6 additions & 5 deletions xpublish_edr/query.py
@@ -38,18 +38,18 @@ def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset:
        """Select data from a dataset based on the query"""
        if self.z:
            if self.method == "nearest":
-                ds = ds.cf.sel(Z=self.z, method=self.method)
+                ds = ds.cf.sel(Z=[self.z], method=self.method)
            else:
-                ds = ds.cf.interp(Z=self.z, method=self.method)
+                ds = ds.cf.interp(Z=[self.z], method=self.method)

        if self.datetime:
            try:
                datetimes = self.datetime.split("/")
                if len(datetimes) == 1:
                    if self.method == "nearest":
-                        ds = ds.cf.sel(T=datetimes[0], method=self.method)
+                        ds = ds.cf.sel(T=datetimes, method=self.method)
                    else:
-                        ds = ds.cf.interp(T=datetimes[0], method=self.method)
+                        ds = ds.cf.interp(T=datetimes, method=self.method)
                elif len(datetimes) == 2:
                    ds = ds.cf.sel(T=slice(datetimes[0], datetimes[1]))
                else:
Expand All @@ -74,7 +74,7 @@ def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset:
for key, value in query_params.items():
split_value = value.split("/")
if len(split_value) == 1:
continue
query_params[key] = [split_value[0]]
elif len(split_value) == 2:
query_params[key] = slice(split_value[0], split_value[1])
else:
@@ -90,6 +90,7 @@ def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset:
                except Exception as e:
                    logger.warning(f"Interpolation failed: {e}, falling back to selection")
                    ds = ds.sel(query_params, method="nearest")
+
        return ds


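
The single-datetime branch above benefits the same way: passing the unsplit one-element `datetimes` list keeps a length-1 time dimension, which is exactly what the new single-time-step CSV test (one header row plus one data row) exercises. The extra-dimension parameters in the second hunk get the same treatment, wrapping a single value in a one-element list instead of skipping the conversion. A hedged standalone sketch (coordinate values and attributes are illustrative):

import cf_xarray  # noqa: F401  # registers the .cf accessor
import numpy as np
import pandas as pd
import xarray as xr

ds = xr.Dataset(
    {"air": (("time",), np.arange(4.0))},
    coords={"time": ("time", pd.date_range("2013-01-01", periods=4), {"axis": "T"})},
)

datetimes = "2013-01-01T00:00:00".split("/")  # -> ["2013-01-01T00:00:00"]
selected = ds.cf.sel(T=datetimes, method="nearest")
print(dict(selected.sizes))          # {'time': 1} -- the time dimension is kept
print(len(selected.to_dataframe()))  # 1 data row, matching one header row + one data row in CSV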
