diff --git a/README.md b/README.md
index b9e76e2..d5d09fa 100644
--- a/README.md
+++ b/README.md
@@ -69,7 +69,7 @@ In the future, when `xpublish` supports [`DataTree`](https://docs.xarray.dev/en/
 | `z` | ✅ |  |
 | `datetime` | ✅ |  |
 | `parameter-name` | ✅ |  |
-| `crs` | ❌ | Not currently supported, all coordinates should be in the reference system of the queried dataset |
+| `crs` | ✅ | Requires a CF-compliant [grid mapping](https://cf-xarray.readthedocs.io/en/latest/grid_mappings.html) on the target dataset. Default is `EPSG:4326` |
 | `parameter-name` | ✅ |  |
 | `f` | ✅ |  |
 | `method` | ➕ | Optional: controls data selection. Use "nearest" for nearest neighbor selection, or "linear" for interpolated selection. Uses `nearest` if not specified |
@@ -84,7 +84,7 @@ In the future, when `xpublish` supports [`DataTree`](https://docs.xarray.dev/en/
 | `z` | ✅ |  |
 | `datetime` | ✅ |  |
 | `parameter-name` | ✅ |  |
-| `crs` | ❌ | Not currently supported, all coordinates should be in the reference system of the queried dataset |
+| `crs` | ✅ | Requires a CF-compliant [grid mapping](https://cf-xarray.readthedocs.io/en/latest/grid_mappings.html) on the target dataset. Default is `EPSG:4326` |
 | `parameter-name` | ✅ |  |
 | `f` | ✅ |  |
 | `method` | ➕ | Optional: controls data selection. Use "nearest" for nearest neighbor selection, or "linear" for interpolated selection. Uses `nearest` if not specified |
diff --git a/tests/test_cf_router.py b/tests/test_cf_router.py
index 756d2cd..a7a8041 100644
--- a/tests/test_cf_router.py
+++ b/tests/test_cf_router.py
@@ -7,15 +7,25 @@
 
 
 @pytest.fixture(scope="session")
-def cf_dataset():
+def cf_air_dataset():
     from cf_xarray.datasets import airds
 
     return airds
 
 
 @pytest.fixture(scope="session")
-def cf_xpublish(cf_dataset):
-    rest = xpublish.Rest({"air": cf_dataset}, plugins={"edr": CfEdrPlugin()})
+def cf_temp_dataset():
+    from cf_xarray.datasets import rotds
+
+    return rotds
+
+
+@pytest.fixture(scope="session")
+def cf_xpublish(cf_air_dataset, cf_temp_dataset):
+    rest = xpublish.Rest(
+        {"air": cf_air_dataset, "temp": cf_temp_dataset},
+        plugins={"edr": CfEdrPlugin()},
+    )
 
     return rest
 
@@ -52,7 +62,7 @@ def test_cf_area_formats(cf_client):
     assert "csv" in data, "csv is not a valid format"
 
 
-def test_cf_position_query(cf_client, cf_dataset):
+def test_cf_position_query(cf_client, cf_air_dataset, cf_temp_dataset):
     x = 204
     y = 44
     response = cf_client.get(f"/datasets/air/edr/position?coords=POINT({x} {y})")
@@ -76,17 +86,18 @@ def test_cf_position_query(cf_client, cf_dataset):
     air_param = data["parameters"]["air"]
 
     assert (
-        air_param["unit"]["label"]["en"] == cf_dataset["air"].attrs["units"]
+        air_param["unit"]["label"]["en"] == cf_air_dataset["air"].attrs["units"]
     ), "DataArray units should be set as parameter units"
     assert (
-        air_param["observedProperty"]["id"] == cf_dataset["air"].attrs["standard_name"]
+        air_param["observedProperty"]["id"]
+        == cf_air_dataset["air"].attrs["standard_name"]
     ), "DataArray standard_name should be set as the observed property id"
     assert (
         air_param["observedProperty"]["label"]["en"]
-        == cf_dataset["air"].attrs["long_name"]
+        == cf_air_dataset["air"].attrs["long_name"]
     ), "DataArray long_name should be set as parameter observed property"
     assert (
-        air_param["description"]["en"] == cf_dataset["air"].attrs["long_name"]
+        air_param["description"]["en"] == cf_air_dataset["air"].attrs["long_name"]
     ), "DataArray long_name should be set as parameter description"
 
     air_range = data["ranges"]["air"]
@@ -99,6 +110,34 @@ def test_cf_position_query(cf_client, cf_dataset):
         len(air_range["values"]) == 4
     ), "There should be 4 values, one for each time step"
 
+    # Test with a dataset containing data in a different coordinate system
+    x = 64.59063409
+    y = 66.66454929
+    response = cf_client.get(f"/datasets/temp/edr/position?coords=POINT({x} {y})")
+    assert response.status_code == 200, "Response did not return successfully"
+
+    data = response.json()
+    for key in ("type", "domain", "parameters", "ranges"):
+        assert key in data, f"Key {key} is not a top level key in the CovJSON response"
+
+    axes = data["domain"]["axes"]
+
+    npt.assert_array_almost_equal(
+        axes["longitude"]["values"],
+        [[x]],
+    ), "Did not select nearby x coordinate"
+    npt.assert_array_almost_equal(
+        axes["latitude"]["values"],
+        [[y]],
+    ), "Did not select a nearby y coordinate"
+
+    temp_range = data["ranges"]["temp"]
+    assert temp_range["type"] == "NdArray", "Response range should be a NdArray"
+    assert temp_range["dataType"] == "float", "Temp dataType should be floats"
+    assert temp_range["axisNames"] == ["rlat", "rlon"], "All dimensions should persist"
+    assert temp_range["shape"] == [1, 1], "The shape of the array should be 1x1"
+    assert len(temp_range["values"]) == 1, "There should be 1 value selected"
+
 
 def test_cf_position_csv(cf_client):
     x = 204
@@ -346,7 +385,7 @@ def test_cf_multiple_position_csv(cf_client):
         assert key in csv_data[0], f"column {key} should be in the header"
 
 
-def test_cf_area_query(cf_client, cf_dataset):
+def test_cf_area_query(cf_client, cf_air_dataset):
     coords = "POLYGON((201 41, 201 49, 209 49, 209 41, 201 41))"
     response = cf_client.get(f"/datasets/air/edr/area?coords={coords}&f=cf_covjson")
 
@@ -372,17 +411,18 @@ def test_cf_area_query(cf_client, cf_dataset):
     air_param = data["parameters"]["air"]
 
     assert (
-        air_param["unit"]["label"]["en"] == cf_dataset["air"].attrs["units"]
+        air_param["unit"]["label"]["en"] == cf_air_dataset["air"].attrs["units"]
     ), "DataArray units should be set as parameter units"
     assert (
-        air_param["observedProperty"]["id"] == cf_dataset["air"].attrs["standard_name"]
+        air_param["observedProperty"]["id"]
+        == cf_air_dataset["air"].attrs["standard_name"]
     ), "DataArray standard_name should be set as the observed property id"
     assert (
         air_param["observedProperty"]["label"]["en"]
-        == cf_dataset["air"].attrs["long_name"]
+        == cf_air_dataset["air"].attrs["long_name"]
     ), "DataArray long_name should be set as parameter observed property"
     assert (
-        air_param["description"]["en"] == cf_dataset["air"].attrs["long_name"]
+        air_param["description"]["en"] == cf_air_dataset["air"].attrs["long_name"]
     ), "DataArray long_name should be set as parameter description"
 
     air_range = data["ranges"]["air"]
@@ -403,7 +443,7 @@ def test_cf_area_query(cf_client, cf_dataset):
     ), "There should be 26 values, 9 for each time step"
 
 
-def test_cf_area_csv_query(cf_client, cf_dataset):
+def test_cf_area_csv_query(cf_client, cf_air_dataset):
     coords = "POLYGON((201 41, 201 49, 209 49, 209 41, 201 41))"
     response = cf_client.get(f"/datasets/air/edr/area?coords={coords}&f=csv")
 
@@ -427,7 +467,7 @@ def test_cf_area_csv_query(cf_client, cf_dataset):
         assert key in csv_data[0], f"column {key} should be in the header"
 
 
-def test_cf_area_geojson_query(cf_client, cf_dataset):
+def test_cf_area_geojson_query(cf_client, cf_air_dataset):
     coords = "POLYGON((201 41, 201 49, 209 49, 209 41, 201 41))"
     response = cf_client.get(f"/datasets/air/edr/area?coords={coords}&f=geojson")
 
@@ -446,7 +486,7 @@ def test_cf_area_geojson_query(cf_client, cf_dataset):
"There should be 36 data points" -def test_cf_area_nc_query(cf_client, cf_dataset): +def test_cf_area_nc_query(cf_client, cf_air_dataset): coords = "POLYGON((201 41, 201 49, 209 49, 209 41, 201 41))" response = cf_client.get(f"/datasets/air/edr/area?coords={coords}&f=nc") diff --git a/tests/test_select.py b/tests/test_select.py index 28ead14..e54f41d 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -4,9 +4,11 @@ import pandas as pd import pytest import xarray as xr +import xarray.testing as xrt from shapely import MultiPoint, Point, from_wkt from xpublish_edr.geometry.area import select_by_area +from xpublish_edr.geometry.common import project_dataset from xpublish_edr.geometry.position import select_by_position from xpublish_edr.query import EDRQuery @@ -17,6 +19,14 @@ def regular_xy_dataset(): return xr.tutorial.load_dataset("air_temperature") +@pytest.fixture(scope="function") +def projected_xy_dataset(): + """Loads a sample dataset with projected X and Y coordinates""" + from cf_xarray.datasets import rotds + + return rotds + + def test_select_query(regular_xy_dataset): query = EDRQuery( coords="POINT(200 45)", @@ -134,6 +144,44 @@ def test_select_position_regular_xy(regular_xy_dataset): npt.assert_approx_equal(ds["air"][-1], 279.19), "Temperature is incorrect" +def test_select_position_projected_xy(projected_xy_dataset): + query = EDRQuery( + coords="POINT(64.59063409 66.66454929)", + crs="EPSG:4326", + ) + + projected_point = query.project_geometry(projected_xy_dataset) + npt.assert_approx_equal(projected_point.x, 18.045), "Longitude is incorrect" + npt.assert_approx_equal(projected_point.y, 21.725), "Latitude is incorrect" + + ds = select_by_position(projected_xy_dataset, projected_point) + xrt.assert_identical( + ds, + projected_xy_dataset.sel(rlon=[18.045], rlat=[21.725], method="nearest"), + ) + + projected_ds = project_dataset(ds, query.crs) + ( + npt.assert_approx_equal(projected_ds.longitude.values, 64.59063409), + "Longitude is incorrect", + ) + ( + npt.assert_approx_equal(projected_ds.latitude.values, 66.66454929), + "Latitude is incorrect", + ) + ( + npt.assert_approx_equal( + projected_ds.temp.values, + projected_xy_dataset.sel( + rlon=[18.045], + rlat=[21.725], + method="nearest", + ).temp.values, + ), + "Temperature is incorrect", + ) + + def test_select_position_regular_xy_interpolate(regular_xy_dataset): point = Point((204, 44)) ds = select_by_position(regular_xy_dataset, point, method="linear") @@ -170,6 +218,36 @@ def test_select_position_regular_xy_multi(regular_xy_dataset): ) +def test_select_position_projected_xy_multi(projected_xy_dataset): + query = EDRQuery( + coords="MULTIPOINT(64.3 66.6, 64.6 66.5)", + crs="EPSG:4326", + method="linear", + ) + + projected_points = query.project_geometry(projected_xy_dataset) + ds = select_by_position(projected_xy_dataset, projected_points, method="linear") + projected_ds = project_dataset(ds, query.crs) + assert "temp" in projected_ds, "Dataset does not contain the temp variable" + assert "rlon" not in projected_ds, "Dataset does not contain the rlon variable" + assert "rlat" not in projected_ds, "Dataset does not contain the rlat variable" + ( + npt.assert_array_almost_equal(projected_ds.longitude, [64.3, 64.6]), + "Longitude is incorrect", + ) + ( + npt.assert_array_almost_equal(projected_ds.latitude, [66.6, 66.5]), + "Latitude is incorrect", + ) + ( + npt.assert_array_almost_equal( + ds.temp, + projected_ds.temp, + ), + "Temperature is incorrect", + ) + + def 
 def test_select_position_regular_xy_multi_interpolate(regular_xy_dataset):
     points = MultiPoint([(202, 45), (205, 48)])
     ds = select_by_position(regular_xy_dataset, points, method="linear")
@@ -236,6 +314,28 @@ def test_select_area_regular_xy(regular_xy_dataset):
     )
 
 
+def test_select_area_projected_xy(projected_xy_dataset):
+    query = EDRQuery(
+        coords="POLYGON((64.3 66.82, 64.5 66.82, 64.5 66.6, 64.3 66.6, 64.3 66.82))",
+        crs="EPSG:4326",
+    )
+
+    projected_area = query.project_geometry(projected_xy_dataset)
+    ds = select_by_area(projected_xy_dataset, projected_area)
+    projected_ds = project_dataset(ds, query.crs)
+
+    assert projected_ds is not None, "Dataset was not returned"
+    assert "temp" in projected_ds, "Dataset does not contain the temp variable"
+    assert "latitude" in projected_ds, "Dataset does not contain the latitude variable"
+    assert (
+        "longitude" in projected_ds
+    ), "Dataset does not contain the longitude variable"
+
+    assert projected_ds.longitude.shape[0] == 1, "Longitude shape is incorrect"
+    assert projected_ds.latitude.shape[0] == 1, "Latitude shape is incorrect"
+    assert projected_ds.temp.shape[0] == 1, "Temperature shape is incorrect"
+
+
 def test_select_area_regular_xy_boundary(regular_xy_dataset):
     polygon = from_wkt("POLYGON((200 40, 200 50, 210 50, 210 40, 200 40))").buffer(
         0.0001,
diff --git a/xpublish_edr/geometry/area.py b/xpublish_edr/geometry/area.py
index dd37d5a..d2a0a78 100644
--- a/xpublish_edr/geometry/area.py
+++ b/xpublish_edr/geometry/area.py
@@ -31,7 +31,16 @@ def _select_area_regular_xy_grid(
     """
     # To minimize performance impact, we first subset the dataset to the bounding box of the polygon
     (minx, miny, maxx, maxy) = polygon.bounds
-    ds = ds.cf.sel(X=slice(minx, maxx), Y=slice(maxy, miny))
+    indexes = ds.cf.indexes
+    if indexes["X"].is_monotonic_increasing:
+        x_sel = slice(minx, maxx)
+    else:
+        x_sel = slice(maxx, minx)
+    if indexes["Y"].is_monotonic_increasing:
+        y_sel = slice(miny, maxy)
+    else:
+        y_sel = slice(maxy, miny)
+    ds = ds.cf.sel(X=x_sel, Y=y_sel)
 
     # For a regular grid, we can create a meshgrid of the X and Y coordinates to create a spatial mask
     pts = np.meshgrid(ds.cf["X"], ds.cf["Y"])
@@ -45,6 +54,4 @@
     y_sel = xr.Variable(data=y_inds, dims=VECTORIZED_DIM)
 
     # Apply the mask and vectorize to a 1d collection of points
-    ds_sub = ds.cf.isel(X=x_sel, Y=y_sel)
-
-    return ds_sub
+    return ds.cf.isel(X=x_sel, Y=y_sel)
diff --git a/xpublish_edr/geometry/common.py b/xpublish_edr/geometry/common.py
index d189c85..16368c9 100644
--- a/xpublish_edr/geometry/common.py
+++ b/xpublish_edr/geometry/common.py
@@ -2,10 +2,19 @@
 Common geometry handling functions
 """
 
+import itertools
+from functools import lru_cache
+
+import pyproj
 import xarray as xr
+from shapely import Geometry
+from shapely.ops import transform
 
 VECTORIZED_DIM = "pts"
 
+# https://pyproj4.github.io/pyproj/stable/advanced_examples.html#caching-pyproj-objects
+transformer_from_crs = lru_cache(pyproj.Transformer.from_crs)
+
 
 def coord_is_regular(da: xr.DataArray) -> bool:
     """
@@ -19,3 +28,103 @@ def is_regular_xy_coords(ds: xr.Dataset) -> bool:
     Check if the dataset has 2D coordinates
     """
     return coord_is_regular(ds.cf["X"]) and coord_is_regular(ds.cf["Y"])
+
+
+def dataset_crs(ds: xr.Dataset) -> pyproj.CRS:
+    grid_mapping_names = ds.cf.grid_mapping_names
+    if len(grid_mapping_names) == 0:
+        # Default to WGS84
+        keys = ds.cf.keys()
+        if "latitude" in keys and "longitude" in keys:
+            return pyproj.CRS.from_epsg(4326)
+        else:
+            raise ValueError("Unknown coordinate system")
+    if len(grid_mapping_names) > 1:
+        raise ValueError(f"Multiple grid mappings found: {grid_mapping_names!r}!")
+    (grid_mapping_var,) = tuple(itertools.chain(*ds.cf.grid_mapping_names.values()))
+
+    grid_mapping = ds[grid_mapping_var]
+    return pyproj.CRS.from_cf(grid_mapping.attrs)
+
+
+def project_geometry(ds: xr.Dataset, geometry_crs: str, geometry: Geometry) -> Geometry:
+    """
+    Project the given geometry into the dataset's CRS
+    """
+    data_crs = dataset_crs(ds)
+
+    transformer = transformer_from_crs(
+        crs_from=geometry_crs,
+        crs_to=data_crs,
+        always_xy=True,
+    )
+    return transform(transformer.transform, geometry)
+
+
+def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset:
+    """
+    Project the dataset to the given CRS
+    """
+    data_crs = dataset_crs(ds)
+    target_crs = pyproj.CRS.from_string(query_crs)
+    if data_crs == target_crs:
+        return ds
+
+    transformer = transformer_from_crs(
+        crs_from=data_crs,
+        crs_to=target_crs,
+        always_xy=True,
+    )
+
+    # TODO: Handle rotated pole
+    cf_coords = target_crs.coordinate_system.to_cf()
+
+    # Get the new X and Y coordinates
+    target_y_coord = next(coord for coord in cf_coords if coord["axis"] == "Y")
+    target_x_coord = next(coord for coord in cf_coords if coord["axis"] == "X")
+
+    X = ds.cf["X"]
+    Y = ds.cf["Y"]
+
+    # Transform the coordinates
+    # If the data is vectorized, we just transform the points in full
+    # TODO: Handle 2D coordinates
+    if len(X.dims) > 1 or len(Y.dims) > 1:
+        raise NotImplementedError("Only 1D coordinates are supported")
+
+    x, y = xr.broadcast(X, Y)
+    target_dims = x.dims
+
+    x, y = transformer.transform(x, y)
+
+    x_dim = X.dims[0]
+    y_dim = Y.dims[0]
+
+    coords_to_drop = [
+        c for c in ds.coords if x_dim in ds[c].dims or y_dim in ds[c].dims
+    ]
+
+    target_x_coord_name = target_x_coord["standard_name"]
+    target_y_coord_name = target_y_coord["standard_name"]
+
+    stdnames = ds.cf.standard_names
+    coords_to_drop += list(
+        itertools.chain(
+            stdnames.get(target_x_coord_name, []),
+            stdnames.get(target_y_coord_name, []),
+        ),
+    )
+    ds = ds.drop_vars(coords_to_drop)
+
+    # Create the new dataset with vectorized coordinates
+    ds = ds.assign_coords(
+        {
+            target_x_coord_name: (target_dims, x),
+            target_y_coord_name: (target_dims, y),
+        },
+    )
+
+    if x_dim != y_dim:
+        ds = ds.transpose(..., y_dim, x_dim)
+
+    return ds
diff --git a/xpublish_edr/geometry/position.py b/xpublish_edr/geometry/position.py
index 671409b..f995e06 100644
--- a/xpublish_edr/geometry/position.py
+++ b/xpublish_edr/geometry/position.py
@@ -60,8 +60,10 @@ def _select_by_multiple_positions_regular_xy_grid(
     """
     # Find the nearest X and Y coordinates to the point using vectorized indexing
    x, y = np.array(list(zip(*[(point.x, point.y) for point in points.geoms])))
-    sel_x = xr.Variable(data=x, dims=VECTORIZED_DIM)
-    sel_y = xr.Variable(data=y, dims=VECTORIZED_DIM)
+
+    # When using vectorized indexing with interp, we need to persist the attributes explicitly
+    sel_x = xr.Variable(data=x, dims=VECTORIZED_DIM, attrs=ds.cf["X"].attrs)
+    sel_y = xr.Variable(data=y, dims=VECTORIZED_DIM, attrs=ds.cf["Y"].attrs)
     if method == "nearest":
         return ds.cf.sel(X=sel_x, Y=sel_y, method=method)
     else:
diff --git a/xpublish_edr/plugin.py b/xpublish_edr/plugin.py
index bb6d55a..fc8fd71 100644
--- a/xpublish_edr/plugin.py
+++ b/xpublish_edr/plugin.py
@@ -11,6 +11,7 @@
 
 from xpublish_edr.formats.to_covjson import to_cf_covjson
 from xpublish_edr.geometry.area import select_by_area
+from xpublish_edr.geometry.common import project_dataset
 from xpublish_edr.geometry.position import select_by_position
 from xpublish_edr.logger import logger
 from xpublish_edr.query import EDRQuery, edr_query
@@ -101,7 +102,7 @@ def get_position(
             logger.debug(f"Dataset filtered by query params {ds}")
 
             try:
-                ds = select_by_position(ds, query.geometry, query.method)
+                ds = select_by_position(ds, query.project_geometry(ds), query.method)
             except KeyError:
                 raise HTTPException(
                     status_code=404,
@@ -112,6 +113,17 @@
                 f"Dataset filtered by position ({query.geometry}): {ds}",
             )
 
+            try:
+                ds = project_dataset(ds, query.crs)
+            except Exception as e:
+                logger.error(f"Error projecting dataset: {e}")
+                raise HTTPException(
+                    status_code=404,
+                    detail="Error projecting dataset",
+                )
+
+            logger.debug(f"Dataset projected to {query.crs}: {ds}")
+
             if query.format:
                 try:
                     format_fn = position_formats()[query.format]
@@ -148,7 +160,7 @@ def get_area(
             logger.debug(f"Dataset filtered by query params {ds}")
 
             try:
-                ds = select_by_area(ds, query.geometry)
+                ds = select_by_area(ds, query.project_geometry(ds))
             except KeyError:
                 raise HTTPException(
                     status_code=404,
@@ -157,6 +169,17 @@
 
             logger.debug(f"Dataset filtered by polygon {query.geometry.boundary}: {ds}")
 
+            try:
+                ds = project_dataset(ds, query.crs)
+            except Exception as e:
+                logger.error(f"Error projecting dataset: {e}")
+                raise HTTPException(
+                    status_code=404,
+                    detail="Error projecting dataset",
+                )
+
+            logger.debug(f"Dataset projected to {query.crs}: {ds}")
+
             if query.format:
                 try:
                     format_fn = position_formats()[query.format]
diff --git a/xpublish_edr/query.py b/xpublish_edr/query.py
index 8cde252..bee3d5c 100644
--- a/xpublish_edr/query.py
+++ b/xpublish_edr/query.py
@@ -7,8 +7,9 @@
 import xarray as xr
 from fastapi import Query
 from pydantic import BaseModel, Field
-from shapely import wkt
+from shapely import Geometry, wkt
 
+from xpublish_edr.geometry.common import project_geometry
 from xpublish_edr.logger import logger
 
 
@@ -25,15 +26,24 @@ class EDRQuery(BaseModel):
     z: Optional[str] = None
     datetime: Optional[str] = None
     parameters: Optional[str] = None
-    crs: Optional[str] = None
+    crs: str = Field(
+        "EPSG:4326",
+        title="Coordinate Reference System",
+        description="Coordinate Reference System for the query. Default is EPSG:4326",
+    )
     format: Optional[str] = None
     method: Literal["nearest", "linear"] = "nearest"
 
     @property
-    def geometry(self):
+    def geometry(self) -> Geometry:
         """Shapely point from WKT query params"""
         return wkt.loads(self.coords)
 
+    def project_geometry(self, ds: xr.Dataset) -> Geometry:
+        """Project the geometry to the dataset's CRS"""
+        geometry = self.geometry
+        return project_geometry(ds, self.crs, geometry)
+
     def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset:
         """Select data from a dataset based on the query"""
         if self.z:
@@ -118,8 +128,8 @@ def edr_query(
         alias="parameter-name",
         description="xarray variables to query",
     ),
-    crs: Optional[str] = Query(
-        None,
-        deprecated=True,
-        description="CRS is not yet implemented",
+    crs: str = Query(
+        "EPSG:4326",
+        title="Coordinate Reference System",
+        description="Coordinate Reference System for the query. Default is EPSG:4326",
     ),
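For reviewers, here is a minimal sketch (not part of the patch) of the round trip the new code performs. It assumes the same `rotds` rotated-pole sample dataset from `cf_xarray` that the tests above rely on, uses the helpers introduced in `xpublish_edr/geometry/common.py` and `xpublish_edr/geometry/position.py`, and mirrors the coordinate values from `test_select_position_projected_xy`:

```python
from cf_xarray.datasets import rotds
from shapely import Point

from xpublish_edr.geometry.common import dataset_crs, project_dataset, project_geometry
from xpublish_edr.geometry.position import select_by_position

# The dataset's CRS is derived from its CF grid_mapping variable
crs = dataset_crs(rotds)

# 1. Project the EPSG:4326 query point into the dataset's rotated-pole CRS
point = Point(64.59063409, 66.66454929)
projected_point = project_geometry(rotds, "EPSG:4326", point)

# 2. Select in native coordinates (nearest neighbor by default)
ds = select_by_position(rotds, projected_point)

# 3. Project the selection back to the query CRS for the response
projected = project_dataset(ds, "EPSG:4326")
print(projected.longitude.values, projected.latitude.values)  # roughly 64.59, 66.66
```

The `/position` and `/area` routes follow the same three steps: the query geometry is projected into the dataset's CRS before selection, and the selected data is projected back to the requested `crs` before formatting.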