Skip to content

Commit

Permalink
Fix indexing so dimensions are never squeezed so they are persisted in response
Browse files Browse the repository at this point in the history
  • Loading branch information
mpiannucci committed Nov 14, 2024
1 parent bf9d5d9 commit 18bd668
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 12 deletions.
31 changes: 28 additions & 3 deletions tests/test_cf_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,9 +93,8 @@ def test_cf_position_query(cf_client, cf_dataset):

assert air_range["type"] == "NdArray", "Response range should be a NdArray"
assert air_range["dataType"] == "float", "Air dataType should be floats"
assert air_range["axisNames"] == ["t"], "Time should be the only remaining axes"
assert len(air_range["shape"]) == 1, "There should only one axes"
assert air_range["shape"][0] == len(axes["t"]["values"]), "The shape of the "
assert air_range["axisNames"] == ["t", "y", "x"], "All dimensions should persist"
assert air_range["shape"] == [4, 1, 1], "The shape of the array should be 4x1x1"
assert (
len(air_range["values"]) == 4
), "There should be 4 values, one for each time step"
Expand Down Expand Up @@ -129,6 +128,32 @@ def test_cf_position_csv(cf_client):
for key in ("time", "lat", "lon", "air", "cell_area"):
assert key in csv_data[0], f"column {key} should be in the header"

# single time step test
response = cf_client.get(
f"/datasets/air/edr/position?coords=POINT({x} {y})&f=csv&parameter-name=air&datetime=2013-01-01T00:00:00",
)

assert response.status_code == 200, "Response did not return successfully"
assert (
"text/csv" in response.headers["content-type"]
), "The content type should be set as a CSV"
assert (
"attachment" in response.headers["content-disposition"]
), "The response should be set as an attachment to trigger download"
assert (
"position.csv" in response.headers["content-disposition"]
), "The file name should be position.csv"

csv_data = [
line.split(",") for line in response.content.decode("utf-8").splitlines()
]

assert (
len(csv_data) == 2
), "There should be 2 data rows, one data and one header row"
for key in ("time", "lat", "lon", "air", "cell_area"):
assert key in csv_data[0], f"column {key} should be in the header"


def test_cf_position_csv_interpolate(cf_client):
x = 204
Expand Down
1 change: 0 additions & 1 deletion xpublish_edr/formats/to_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@

def to_csv(ds: xr.Dataset):
"""Return a CSV response from an xarray dataset"""
ds = ds.squeeze()
df = ds.to_dataframe()

csv = df.to_csv()
Expand Down
1 change: 0 additions & 1 deletion xpublish_edr/formats/to_geojson.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ def handle_date_columns(df: pd.DataFrame) -> pd.DataFrame:

def to_geojson(ds: xr.Dataset):
"""Return a GeoJSON response from an xarray dataset"""
ds = ds.squeeze()
axes = ds.cf.axes
(x_col,) = axes["X"]
(y_col,) = axes["Y"]
Expand Down
4 changes: 2 additions & 2 deletions xpublish_edr/geometry/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,9 @@ def _select_by_position_regular_xy_grid(
"""
# Find the nearest X and Y coordinates to the point
if method == "nearest":
return ds.cf.sel(X=point.x, Y=point.y, method=method)
return ds.cf.sel(X=[point.x], Y=[point.y], method=method)
else:
return ds.cf.interp(X=point.x, Y=point.y, method=method)
return ds.cf.interp(X=[point.x], Y=[point.y], method=method)


def _select_by_multiple_positions_regular_xy_grid(
Expand Down
11 changes: 6 additions & 5 deletions xpublish_edr/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,18 @@ def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset:
"""Select data from a dataset based on the query"""
if self.z:
if self.method == "nearest":
ds = ds.cf.sel(Z=self.z, method=self.method)
ds = ds.cf.sel(Z=[self.z], method=self.method)
else:
ds = ds.cf.interp(Z=self.z, method=self.method)
ds = ds.cf.interp(Z=[self.z], method=self.method)

if self.datetime:
try:
datetimes = self.datetime.split("/")
if len(datetimes) == 1:
if self.method == "nearest":
ds = ds.cf.sel(T=datetimes[0], method=self.method)
ds = ds.cf.sel(T=datetimes, method=self.method)
else:
ds = ds.cf.interp(T=datetimes[0], method=self.method)
ds = ds.cf.interp(T=datetimes, method=self.method)
elif len(datetimes) == 2:
ds = ds.cf.sel(T=slice(datetimes[0], datetimes[1]))
else:
Expand All @@ -74,7 +74,7 @@ def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset:
for key, value in query_params.items():
split_value = value.split("/")
if len(split_value) == 1:
continue
query_params[key] = [split_value[0]]
elif len(split_value) == 2:
query_params[key] = slice(split_value[0], split_value[1])
else:
Expand All @@ -90,6 +90,7 @@ def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset:
except Exception as e:
logger.warning(f"Interpolation failed: {e}, falling back to selection")
ds = ds.sel(query_params, method="nearest")

return ds


Expand Down

0 comments on commit 18bd668

Please sign in to comment.