From 18bd66841a89ebe17d54247ca283b6e32e434cf9 Mon Sep 17 00:00:00 2001
From: Matthew Iannucci
Date: Thu, 14 Nov 2024 09:12:10 -0500
Subject: [PATCH] Fix indexing so dimensions are never squeezed and are persisted in the response

---
 tests/test_cf_router.py            | 31 +++++++++++++++++++++++++++---
 xpublish_edr/formats/to_csv.py     |  1 -
 xpublish_edr/formats/to_geojson.py |  1 -
 xpublish_edr/geometry/position.py  |  4 ++--
 xpublish_edr/query.py              | 11 ++++++-----
 5 files changed, 36 insertions(+), 12 deletions(-)

diff --git a/tests/test_cf_router.py b/tests/test_cf_router.py
index b117112..6147fab 100644
--- a/tests/test_cf_router.py
+++ b/tests/test_cf_router.py
@@ -93,9 +93,8 @@ def test_cf_position_query(cf_client, cf_dataset):
 
     assert air_range["type"] == "NdArray", "Response range should be a NdArray"
     assert air_range["dataType"] == "float", "Air dataType should be floats"
-    assert air_range["axisNames"] == ["t"], "Time should be the only remaining axes"
-    assert len(air_range["shape"]) == 1, "There should only one axes"
-    assert air_range["shape"][0] == len(axes["t"]["values"]), "The shape of the "
+    assert air_range["axisNames"] == ["t", "y", "x"], "All dimensions should persist"
+    assert air_range["shape"] == [4, 1, 1], "The shape of the array should be 4x1x1"
     assert (
         len(air_range["values"]) == 4
     ), "There should be 4 values, one for each time step"
@@ -129,6 +128,32 @@ def test_cf_position_csv(cf_client):
     for key in ("time", "lat", "lon", "air", "cell_area"):
         assert key in csv_data[0], f"column {key} should be in the header"
 
+    # single time step test
+    response = cf_client.get(
+        f"/datasets/air/edr/position?coords=POINT({x} {y})&f=csv&parameter-name=air&datetime=2013-01-01T00:00:00",
+    )
+
+    assert response.status_code == 200, "Response did not return successfully"
+    assert (
+        "text/csv" in response.headers["content-type"]
+    ), "The content type should be set as a CSV"
+    assert (
+        "attachment" in response.headers["content-disposition"]
+    ), "The response should be set as an attachment to trigger download"
+    assert (
+        "position.csv" in response.headers["content-disposition"]
+    ), "The file name should be position.csv"
+
+    csv_data = [
+        line.split(",") for line in response.content.decode("utf-8").splitlines()
+    ]
+
+    assert (
+        len(csv_data) == 2
+    ), "There should be 2 rows, one header row and one data row"
+    for key in ("time", "lat", "lon", "air", "cell_area"):
+        assert key in csv_data[0], f"column {key} should be in the header"
+
 
 def test_cf_position_csv_interpolate(cf_client):
     x = 204
diff --git a/xpublish_edr/formats/to_csv.py b/xpublish_edr/formats/to_csv.py
index 5dff27e..a2ebe5a 100644
--- a/xpublish_edr/formats/to_csv.py
+++ b/xpublish_edr/formats/to_csv.py
@@ -8,7 +8,6 @@
 
 def to_csv(ds: xr.Dataset):
     """Return a CSV response from an xarray dataset"""
-    ds = ds.squeeze()
     df = ds.to_dataframe()
     csv = df.to_csv()
 
diff --git a/xpublish_edr/formats/to_geojson.py b/xpublish_edr/formats/to_geojson.py
index 06c5bea..d29443c 100644
--- a/xpublish_edr/formats/to_geojson.py
+++ b/xpublish_edr/formats/to_geojson.py
@@ -20,7 +20,6 @@ def handle_date_columns(df: pd.DataFrame) -> pd.DataFrame:
 
 def to_geojson(ds: xr.Dataset):
     """Return a GeoJSON response from an xarray dataset"""
-    ds = ds.squeeze()
     axes = ds.cf.axes
     (x_col,) = axes["X"]
     (y_col,) = axes["Y"]
diff --git a/xpublish_edr/geometry/position.py b/xpublish_edr/geometry/position.py
index e523db6..671409b 100644
--- a/xpublish_edr/geometry/position.py
+++ b/xpublish_edr/geometry/position.py
@@ -45,9 +45,9 @@ def _select_by_position_regular_xy_grid(
     """
    # Find the nearest X and Y coordinates to the point
     if method == "nearest":
-        return ds.cf.sel(X=point.x, Y=point.y, method=method)
+        return ds.cf.sel(X=[point.x], Y=[point.y], method=method)
     else:
-        return ds.cf.interp(X=point.x, Y=point.y, method=method)
+        return ds.cf.interp(X=[point.x], Y=[point.y], method=method)
 
 
 def _select_by_multiple_positions_regular_xy_grid(
diff --git a/xpublish_edr/query.py b/xpublish_edr/query.py
index 9e882f0..8cde252 100644
--- a/xpublish_edr/query.py
+++ b/xpublish_edr/query.py
@@ -38,18 +38,18 @@ def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset:
         """Select data from a dataset based on the query"""
         if self.z:
             if self.method == "nearest":
-                ds = ds.cf.sel(Z=self.z, method=self.method)
+                ds = ds.cf.sel(Z=[self.z], method=self.method)
             else:
-                ds = ds.cf.interp(Z=self.z, method=self.method)
+                ds = ds.cf.interp(Z=[self.z], method=self.method)
 
         if self.datetime:
             try:
                 datetimes = self.datetime.split("/")
                 if len(datetimes) == 1:
                     if self.method == "nearest":
-                        ds = ds.cf.sel(T=datetimes[0], method=self.method)
+                        ds = ds.cf.sel(T=datetimes, method=self.method)
                     else:
-                        ds = ds.cf.interp(T=datetimes[0], method=self.method)
+                        ds = ds.cf.interp(T=datetimes, method=self.method)
                 elif len(datetimes) == 2:
                     ds = ds.cf.sel(T=slice(datetimes[0], datetimes[1]))
                 else:
@@ -74,7 +74,7 @@ def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset:
         for key, value in query_params.items():
             split_value = value.split("/")
             if len(split_value) == 1:
-                continue
+                query_params[key] = [split_value[0]]
             elif len(split_value) == 2:
                 query_params[key] = slice(split_value[0], split_value[1])
             else:
@@ -90,6 +90,7 @@ def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset:
         except Exception as e:
             logger.warning(f"Interpolation failed: {e}, falling back to selection")
             ds = ds.sel(query_params, method="nearest")
+
         return ds
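
The crux of the change is switching from scalar selection (X=point.x, Z=self.z,
T=datetimes[0]) to single-element list selection (X=[point.x], Z=[self.z],
T=datetimes). The sketch below is not part of the patch; it uses a made-up toy
dataset and plain xarray's ds.sel rather than the cf-xarray ds.cf.sel accessor
used in the code, but it illustrates the same indexing behaviour the patch
relies on.

    import numpy as np
    import xarray as xr

    # Toy dataset standing in for the EDR source data (illustrative only).
    ds = xr.Dataset(
        {"air": (("time", "lat", "lon"), np.arange(24).reshape(4, 2, 3))},
        coords={
            "time": np.arange(4),
            "lat": [10.0, 20.0],
            "lon": [100.0, 110.0, 120.0],
        },
    )

    # Scalar labels drop the selected dimensions from the result.
    squeezed = ds.sel(lat=20.0, lon=110.0, method="nearest")
    print(squeezed["air"].dims)   # ('time',)

    # Single-element lists keep lat/lon as size-1 dimensions, so they persist
    # into downstream responses; the shape matches the test's expected [4, 1, 1].
    kept = ds.sel(lat=[20.0], lon=[110.0], method="nearest")
    print(kept["air"].dims)       # ('time', 'lat', 'lon')
    print(kept["air"].shape)      # (4, 1, 1)

Because the size-1 dimensions are now kept at the selection step, the
ds.squeeze() calls removed from to_csv and to_geojson are no longer needed;
squeezing there would collapse exactly the axes this patch sets out to preserve.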