From 52573cee4b65c82d97df853469d5be3a2646de81 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Tue, 12 Nov 2024 09:15:10 -0500 Subject: [PATCH 01/21] Initial CRS support --- xpublish_edr/geometry/common.py | 36 +++++++++++++++++++++++++++++++++ xpublish_edr/plugin.py | 4 ++-- xpublish_edr/query.py | 20 +++++++++++++----- 3 files changed, 53 insertions(+), 7 deletions(-) diff --git a/xpublish_edr/geometry/common.py b/xpublish_edr/geometry/common.py index d189c85..30c1d48 100644 --- a/xpublish_edr/geometry/common.py +++ b/xpublish_edr/geometry/common.py @@ -2,10 +2,19 @@ Common geometry handling functions """ +import itertools +from functools import lru_cache + +import pyproj import xarray as xr +from shapely import Geometry +from shapely.ops import transform VECTORIZED_DIM = "pts" +# https://pyproj4.github.io/pyproj/stable/advanced_examples.html#caching-pyproj-objectshttps://pyproj4.github.io/pyproj/stable/advanced_examples.html#caching-pyproj-objects +transformer_from_crs = lru_cache(pyproj.Transformer.from_crs) + def coord_is_regular(da: xr.DataArray) -> bool: """ @@ -19,3 +28,30 @@ def is_regular_xy_coords(ds: xr.Dataset) -> bool: Check if the dataset has 2D coordinates """ return coord_is_regular(ds.cf["X"]) and coord_is_regular(ds.cf["Y"]) + + +def project_geometry(ds: xr.Dataset, geometry_crs: str, geometry: Geometry) -> Geometry: + """ + Get the projection from the dataset + """ + grid_mapping_names = ds.cf.grid_mapping_names + if len(grid_mapping_names) == 0: + # TODO: Should we require a grid mapping? For now return as is + return geometry + if len(grid_mapping_names) > 1: + raise ValueError(f"Multiple grid mappings found: {grid_mapping_names!r}!") + (grid_mapping_var,) = tuple(itertools.chain(*ds.cf.grid_mapping_names.values())) + + grid_mapping = ds[grid_mapping_var] + data_crs = pyproj.crs.CRS.from_cf(grid_mapping.attrs) + if not data_crs.is_projected: + raise ValueError( + "This method is intended to be used with projected coordinate systems.", + ) + + transformer = transformer_from_crs( + crs_from=geometry_crs, + crs_to=data_crs, + always_xy=True, + ) + return transform(transformer.transform, geometry) diff --git a/xpublish_edr/plugin.py b/xpublish_edr/plugin.py index bb6d55a..c7866f1 100644 --- a/xpublish_edr/plugin.py +++ b/xpublish_edr/plugin.py @@ -101,7 +101,7 @@ def get_position( logger.debug(f"Dataset filtered by query params {ds}") try: - ds = select_by_position(ds, query.geometry, query.method) + ds = select_by_position(ds, query.project_geometry(ds), query.method) except KeyError: raise HTTPException( status_code=404, @@ -148,7 +148,7 @@ def get_area( logger.debug(f"Dataset filtered by query params {ds}") try: - ds = select_by_area(ds, query.geometry) + ds = select_by_area(ds, query.project_geometry(ds)) except KeyError: raise HTTPException( status_code=404, diff --git a/xpublish_edr/query.py b/xpublish_edr/query.py index 9e882f0..439f375 100644 --- a/xpublish_edr/query.py +++ b/xpublish_edr/query.py @@ -7,8 +7,9 @@ import xarray as xr from fastapi import Query from pydantic import BaseModel, Field -from shapely import wkt +from shapely import Geometry, wkt +from xpublish_edr.geometry.common import project_geometry from xpublish_edr.logger import logger @@ -25,15 +26,24 @@ class EDRQuery(BaseModel): z: Optional[str] = None datetime: Optional[str] = None parameters: Optional[str] = None - crs: Optional[str] = None + crs: str = Field( + "EPSG:4326", + title="Coordinate Reference System", + description="Coordinate Reference System for the query. Default is EPSG:4326", + ) format: Optional[str] = None method: Literal["nearest", "linear"] = "nearest" @property - def geometry(self): + def geometry(self) -> Geometry: """Shapely point from WKT query params""" return wkt.loads(self.coords) + def project_geometry(self, ds: xr.Dataset) -> Geometry: + """Project the geometry to the dataset's CRS""" + geometry = self.geometry + return project_geometry(ds, self.crs, geometry) + def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset: """Select data from a dataset based on the query""" if self.z: @@ -117,8 +127,8 @@ def edr_query( alias="parameter-name", description="xarray variables to query", ), - crs: Optional[str] = Query( - None, + crs: str = Query( + "EPSG:4326", deprecated=True, description="CRS is not yet implemented", ), From 4e6207af629fb7a71ab50b841466b494903c027d Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Tue, 12 Nov 2024 09:21:59 -0500 Subject: [PATCH 02/21] Update readme --- README.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index b9e76e2..d5d09fa 100644 --- a/README.md +++ b/README.md @@ -69,7 +69,7 @@ In the future, when `xpublish` supports [`DataTree`](https://docs.xarray.dev/en/ | `z` | ✅ | | | `datetime` | ✅ | | | `parameter-name` | ✅ | | -| `crs` | ❌ | Not currently supported, all coordinates should be in the reference system of the queried dataset | +| `crs` | ✅ | Requires a CF compliant [grid mapping](https://cf-xarray.readthedocs.io/en/latest/grid_mappings.html) on the target dataset. Default is `EPSG:4326` | | `parameter-name` | ✅ | | | `f` | ✅ | | | `method` | ➕ | Optional: controls data selection. Use "nearest" for nearest neighbor selection, or "linear" for interpolated selection. Uses `nearest` if not specified | @@ -84,7 +84,7 @@ In the future, when `xpublish` supports [`DataTree`](https://docs.xarray.dev/en/ | `z` | ✅ | | | `datetime` | ✅ | | | `parameter-name` | ✅ | | -| `crs` | ❌ | Not currently supported, all coordinates should be in the reference system of the queried dataset | +| `crs` | ✅ | Requires a CF compliant [grid mapping](https://cf-xarray.readthedocs.io/en/latest/grid_mappings.html) on the target dataset. Default is `EPSG:4326` | | `parameter-name` | ✅ | | | `f` | ✅ | | | `method` | ➕ | Optional: controls data selection. Use "nearest" for nearest neighbor selection, or "linear" for interpolated selection. Uses `nearest` if not specified | From 331beaeb5f53e5f9a03bfa228410260b89625e38 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 14 Nov 2024 10:45:01 -0500 Subject: [PATCH 03/21] IDK how this is failing... it works fine in notebook --- tests/test_select.py | 21 +++++++++++++++++++++ xpublish_edr/geometry/common.py | 4 ---- xpublish_edr/geometry/position.py | 2 +- 3 files changed, 22 insertions(+), 5 deletions(-) diff --git a/tests/test_select.py b/tests/test_select.py index 8c7ff98..491ef64 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -17,6 +17,13 @@ def regular_xy_dataset(): return xr.tutorial.load_dataset("air_temperature") +@pytest.fixture(scope="function") +def projected_xy_dataset(): + """Loads a sample dataset with projected X and Y coordinates""" + from cf_xarray.datasets import rotds + return rotds + + def test_select_query(regular_xy_dataset): query = EDRQuery( coords="POINT(200 45)", @@ -134,6 +141,20 @@ def test_select_position_regular_xy(regular_xy_dataset): npt.assert_approx_equal(ds["air"][-1], 279.19), "Temperature is incorrect" +def test_select_position_projected_xy(projected_xy_dataset): + from xpublish_edr.geometry.common import project_geometry + + point = Point((64.59063409, 66.66454929)) + projected_point = project_geometry(projected_xy_dataset, "EPSG:4326", point) + npt.assert_approx_equal(projected_point.x, 18.045), "Longitude is incorrect" + npt.assert_approx_equal(projected_point.y, 21.725), "Latitude is incorrect" + + ds = select_by_position(projected_xy_dataset, projected_point) + npt.assert_approx_equal(ds.rlon.values, 18.045), "Longitude is incorrect" + npt.assert_approx_equal(ds.rlat.values, 21.725), "Latitude is incorrect" + npt.assert_approx_equal(ds.temp.values, 0.89959461), "Temperature is incorrect" + + def test_select_position_regular_xy_interpolate(regular_xy_dataset): point = Point((204, 44)) ds = select_by_position(regular_xy_dataset, point, method="linear") diff --git a/xpublish_edr/geometry/common.py b/xpublish_edr/geometry/common.py index 30c1d48..a339617 100644 --- a/xpublish_edr/geometry/common.py +++ b/xpublish_edr/geometry/common.py @@ -44,10 +44,6 @@ def project_geometry(ds: xr.Dataset, geometry_crs: str, geometry: Geometry) -> G grid_mapping = ds[grid_mapping_var] data_crs = pyproj.crs.CRS.from_cf(grid_mapping.attrs) - if not data_crs.is_projected: - raise ValueError( - "This method is intended to be used with projected coordinate systems.", - ) transformer = transformer_from_crs( crs_from=geometry_crs, diff --git a/xpublish_edr/geometry/position.py b/xpublish_edr/geometry/position.py index e523db6..ef4eef5 100644 --- a/xpublish_edr/geometry/position.py +++ b/xpublish_edr/geometry/position.py @@ -45,7 +45,7 @@ def _select_by_position_regular_xy_grid( """ # Find the nearest X and Y coordinates to the point if method == "nearest": - return ds.cf.sel(X=point.x, Y=point.y, method=method) + return ds.cf.sel(X=[point.x], Y=[point.y], method=method) else: return ds.cf.interp(X=point.x, Y=point.y, method=method) From 7dc8dfc83cb0301135cf4ffcfa36ccf9fe92360a Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 14 Nov 2024 10:45:51 -0500 Subject: [PATCH 04/21] lint --- tests/test_select.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/test_select.py b/tests/test_select.py index 491ef64..04e801a 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -21,6 +21,7 @@ def regular_xy_dataset(): def projected_xy_dataset(): """Loads a sample dataset with projected X and Y coordinates""" from cf_xarray.datasets import rotds + return rotds From 7fbb3fcc0fe42dc576a9c9be10b28f7e76d34986 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 14 Nov 2024 10:58:54 -0500 Subject: [PATCH 05/21] Fix test --- tests/test_select.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_select.py b/tests/test_select.py index 00b89b2..5043522 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -4,6 +4,7 @@ import pandas as pd import pytest import xarray as xr +import xarray.testing as xrt from shapely import MultiPoint, Point, from_wkt from xpublish_edr.geometry.area import select_by_area @@ -151,9 +152,7 @@ def test_select_position_projected_xy(projected_xy_dataset): npt.assert_approx_equal(projected_point.y, 21.725), "Latitude is incorrect" ds = select_by_position(projected_xy_dataset, projected_point) - npt.assert_approx_equal(ds.rlon.values, 18.045), "Longitude is incorrect" - npt.assert_approx_equal(ds.rlat.values, 21.725), "Latitude is incorrect" - npt.assert_approx_equal(ds.temp.values, 0.89959461), "Temperature is incorrect" + xrt.assert_equal(ds, projected_xy_dataset.sel(rlon=[18.045], rlat=[21.725], method="nearest")) def test_select_position_regular_xy_interpolate(regular_xy_dataset): From b022ebfbb06eeef81660b59ef9b25771e36f3290 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 14 Nov 2024 11:00:14 -0500 Subject: [PATCH 06/21] Format --- tests/test_select.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/test_select.py b/tests/test_select.py index 5043522..2b95beb 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -152,7 +152,10 @@ def test_select_position_projected_xy(projected_xy_dataset): npt.assert_approx_equal(projected_point.y, 21.725), "Latitude is incorrect" ds = select_by_position(projected_xy_dataset, projected_point) - xrt.assert_equal(ds, projected_xy_dataset.sel(rlon=[18.045], rlat=[21.725], method="nearest")) + xrt.assert_equal( + ds, + projected_xy_dataset.sel(rlon=[18.045], rlat=[21.725], method="nearest"), + ) def test_select_position_regular_xy_interpolate(regular_xy_dataset): From 6f004dd53c56c981465b14807ae98f34125adcd6 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 14 Nov 2024 14:14:01 -0500 Subject: [PATCH 07/21] Convert proj to queries crs --- tests/test_select.py | 21 ++++++++- xpublish_edr/geometry/common.py | 82 ++++++++++++++++++++++++++++++--- 2 files changed, 95 insertions(+), 8 deletions(-) diff --git a/tests/test_select.py b/tests/test_select.py index 2b95beb..530d59f 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -144,7 +144,7 @@ def test_select_position_regular_xy(regular_xy_dataset): def test_select_position_projected_xy(projected_xy_dataset): - from xpublish_edr.geometry.common import project_geometry + from xpublish_edr.geometry.common import project_geometry, project_dataset point = Point((64.59063409, 66.66454929)) projected_point = project_geometry(projected_xy_dataset, "EPSG:4326", point) @@ -157,6 +157,25 @@ def test_select_position_projected_xy(projected_xy_dataset): projected_xy_dataset.sel(rlon=[18.045], rlat=[21.725], method="nearest"), ) + projected_ds = project_dataset(ds, "EPSG:4326") + ( + npt.assert_approx_equal(projected_ds.longitude.values, 64.59063409), + "Longitude is incorrect", + ) + ( + npt.assert_approx_equal(projected_ds.latitude.values, 66.66454929), + "Latitude is incorrect", + ) + ( + npt.assert_approx_equal( + projected_ds.temp.values, + projected_xy_dataset.sel( + rlon=[18.045], rlat=[21.725], method="nearest" + ).temp.values, + ), + "Temperature is incorrect", + ) + def test_select_position_regular_xy_interpolate(regular_xy_dataset): point = Point((204, 44)) diff --git a/xpublish_edr/geometry/common.py b/xpublish_edr/geometry/common.py index a339617..5c5c745 100644 --- a/xpublish_edr/geometry/common.py +++ b/xpublish_edr/geometry/common.py @@ -30,20 +30,24 @@ def is_regular_xy_coords(ds: xr.Dataset) -> bool: return coord_is_regular(ds.cf["X"]) and coord_is_regular(ds.cf["Y"]) -def project_geometry(ds: xr.Dataset, geometry_crs: str, geometry: Geometry) -> Geometry: - """ - Get the projection from the dataset - """ +def dataset_crs(ds: xr.Dataset) -> pyproj.CRS: grid_mapping_names = ds.cf.grid_mapping_names if len(grid_mapping_names) == 0: - # TODO: Should we require a grid mapping? For now return as is - return geometry + # Default to WGS84 + return pyproj.crs.CRS.from_epsg(4326) if len(grid_mapping_names) > 1: raise ValueError(f"Multiple grid mappings found: {grid_mapping_names!r}!") (grid_mapping_var,) = tuple(itertools.chain(*ds.cf.grid_mapping_names.values())) grid_mapping = ds[grid_mapping_var] - data_crs = pyproj.crs.CRS.from_cf(grid_mapping.attrs) + return pyproj.CRS.from_cf(grid_mapping.attrs) + + +def project_geometry(ds: xr.Dataset, geometry_crs: str, geometry: Geometry) -> Geometry: + """ + Get the projection from the dataset + """ + data_crs = dataset_crs(ds) transformer = transformer_from_crs( crs_from=geometry_crs, @@ -51,3 +55,67 @@ def project_geometry(ds: xr.Dataset, geometry_crs: str, geometry: Geometry) -> G always_xy=True, ) return transform(transformer.transform, geometry) + + +def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset: + """ + Project the dataset to the given CRS + """ + data_crs = dataset_crs(ds) + target_crs = pyproj.CRS.from_string(query_crs) + if data_crs == target_crs: + return ds + + transformer = transformer_from_crs( + crs_from=data_crs, + crs_to=target_crs, + always_xy=True, + ) + + # TODO: Handle rotated pole + cf_coords = target_crs.coordinate_system.to_cf() + + # Get the new X and Y coordinates + target_y_coord = next(coord for coord in cf_coords if coord["axis"] == "Y") + target_x_coord = next(coord for coord in cf_coords if coord["axis"] == "X") + + # Transform the coordinates + # If the data is vectorized, we just transform the points in full + # TODO: Handle 2D coordinates + if not is_regular_xy_coords(ds): + raise NotImplementedError("Only 1D coordinates are supported") + + x_dim = ds.cf["X"].dims[0] + y_dim = ds.cf["Y"].dims[0] + if x_dim == [VECTORIZED_DIM]: + x = ds.cf["X"] + y = ds.cf["Y"] + else: + # Otherwise we need to transform the full grid + x, y = xr.broadcast(ds.cf["X"], ds.cf["Y"]) + + x, y = transformer.transform(x, y) + + coords_to_drop = [ + c for c in ds.coords if x_dim in ds[c].dims or y_dim in ds[c].dims + ] + + target_x_coord_name = target_x_coord["standard_name"] + target_y_coord_name = target_y_coord["standard_name"] + + if target_x_coord_name in ds: + target_x_coord_name += "_" + if target_y_coord_name in ds: + target_y_coord_name += "_" + + # Create the new dataset with vectorized coordinates + ds = ds.assign_coords( + { + target_x_coord_name: ((x_dim, y_dim), x), + target_y_coord_name: ((x_dim, y_dim), y), + } + ) + + ds = ds.drop(coords_to_drop) + + return ds From 143d1f740375911eb7815a989113c93fd7926882 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 Nov 2024 19:14:34 +0000 Subject: [PATCH 08/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_select.py | 4 ++-- xpublish_edr/geometry/common.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/test_select.py b/tests/test_select.py index 530d59f..56b0123 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -144,7 +144,7 @@ def test_select_position_regular_xy(regular_xy_dataset): def test_select_position_projected_xy(projected_xy_dataset): - from xpublish_edr.geometry.common import project_geometry, project_dataset + from xpublish_edr.geometry.common import project_dataset, project_geometry point = Point((64.59063409, 66.66454929)) projected_point = project_geometry(projected_xy_dataset, "EPSG:4326", point) @@ -170,7 +170,7 @@ def test_select_position_projected_xy(projected_xy_dataset): npt.assert_approx_equal( projected_ds.temp.values, projected_xy_dataset.sel( - rlon=[18.045], rlat=[21.725], method="nearest" + rlon=[18.045], rlat=[21.725], method="nearest", ).temp.values, ), "Temperature is incorrect", diff --git a/xpublish_edr/geometry/common.py b/xpublish_edr/geometry/common.py index 5c5c745..272fd11 100644 --- a/xpublish_edr/geometry/common.py +++ b/xpublish_edr/geometry/common.py @@ -84,7 +84,7 @@ def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset: # TODO: Handle 2D coordinates if not is_regular_xy_coords(ds): raise NotImplementedError("Only 1D coordinates are supported") - + x_dim = ds.cf["X"].dims[0] y_dim = ds.cf["Y"].dims[0] if x_dim == [VECTORIZED_DIM]: @@ -113,7 +113,7 @@ def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset: { target_x_coord_name: ((x_dim, y_dim), x), target_y_coord_name: ((x_dim, y_dim), y), - } + }, ) ds = ds.drop(coords_to_drop) From a8cb79f5358b2d494fe2bece4325c1ec79d2bf42 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 14 Nov 2024 14:15:43 -0500 Subject: [PATCH 09/21] format --- tests/test_select.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_select.py b/tests/test_select.py index 56b0123..fe49b5e 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -170,7 +170,9 @@ def test_select_position_projected_xy(projected_xy_dataset): npt.assert_approx_equal( projected_ds.temp.values, projected_xy_dataset.sel( - rlon=[18.045], rlat=[21.725], method="nearest", + rlon=[18.045], + rlat=[21.725], + method="nearest", ).temp.values, ), "Temperature is incorrect", From 94cb0912a2dbdd1ef0c6d9c41484154fca53e99b Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 14 Nov 2024 15:20:55 -0500 Subject: [PATCH 10/21] Selection done, with some hackiness --- tests/test_select.py | 38 +++++++++++++++++++++++++++---- xpublish_edr/geometry/common.py | 32 ++++++++++++++++++-------- xpublish_edr/geometry/position.py | 6 +++-- 3 files changed, 60 insertions(+), 16 deletions(-) diff --git a/tests/test_select.py b/tests/test_select.py index fe49b5e..3adda6c 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -8,6 +8,7 @@ from shapely import MultiPoint, Point, from_wkt from xpublish_edr.geometry.area import select_by_area +from xpublish_edr.geometry.common import project_dataset from xpublish_edr.geometry.position import select_by_position from xpublish_edr.query import EDRQuery @@ -144,10 +145,12 @@ def test_select_position_regular_xy(regular_xy_dataset): def test_select_position_projected_xy(projected_xy_dataset): - from xpublish_edr.geometry.common import project_dataset, project_geometry + query = EDRQuery( + coords="POINT(64.59063409 66.66454929)", + crs="EPSG:4326", + ) - point = Point((64.59063409, 66.66454929)) - projected_point = project_geometry(projected_xy_dataset, "EPSG:4326", point) + projected_point = query.project_geometry(projected_xy_dataset) npt.assert_approx_equal(projected_point.x, 18.045), "Longitude is incorrect" npt.assert_approx_equal(projected_point.y, 21.725), "Latitude is incorrect" @@ -157,7 +160,7 @@ def test_select_position_projected_xy(projected_xy_dataset): projected_xy_dataset.sel(rlon=[18.045], rlat=[21.725], method="nearest"), ) - projected_ds = project_dataset(ds, "EPSG:4326") + projected_ds = project_dataset(ds, query.crs) ( npt.assert_approx_equal(projected_ds.longitude.values, 64.59063409), "Longitude is incorrect", @@ -215,6 +218,33 @@ def test_select_position_regular_xy_multi(regular_xy_dataset): ) +def test_select_position_projected_xy_multi(projected_xy_dataset): + query = EDRQuery( + coords="MULTIPOINT(64.3 66.6, 64.6 66.5)", + crs="EPSG:4326", + method="linear", + ) + + projected_points = query.project_geometry(projected_xy_dataset) + ds = select_by_position(projected_xy_dataset, projected_points, method="linear") + projected_ds = project_dataset(ds, query.crs) + assert "temp" in projected_ds, "Dataset does not contain the temp variable" + assert "rlon" not in projected_ds, "Dataset does not contain the rlon variable" + assert "rlat" not in projected_ds, "Dataset does not contain the rlat variable" + ( + npt.assert_array_almost_equal(projected_ds.longitude, [64.3, 64.6]), + "Longitude is incorrect", + ) + ( + npt.assert_array_almost_equal(projected_ds.latitude, [66.6, 66.5]), + "Latitude is incorrect", + ) + npt.assert_array_almost_equal( + ds.temp, + projected_ds.temp, + ), "Temperature is incorrect" + + def test_select_position_regular_xy_multi_interpolate(regular_xy_dataset): points = MultiPoint([(202, 45), (205, 48)]) ds = select_by_position(regular_xy_dataset, points, method="linear") diff --git a/xpublish_edr/geometry/common.py b/xpublish_edr/geometry/common.py index 272fd11..3961b05 100644 --- a/xpublish_edr/geometry/common.py +++ b/xpublish_edr/geometry/common.py @@ -79,20 +79,32 @@ def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset: target_y_coord = next(coord for coord in cf_coords if coord["axis"] == "Y") target_x_coord = next(coord for coord in cf_coords if coord["axis"] == "X") + X = ds.cf["X"] + Y = ds.cf["Y"] + # Transform the coordinates # If the data is vectorized, we just transform the points in full # TODO: Handle 2D coordinates - if not is_regular_xy_coords(ds): + if len(X.dims) > 1 or len(Y.dims) > 1: raise NotImplementedError("Only 1D coordinates are supported") - x_dim = ds.cf["X"].dims[0] - y_dim = ds.cf["Y"].dims[0] - if x_dim == [VECTORIZED_DIM]: - x = ds.cf["X"] - y = ds.cf["Y"] + x_dim = X.dims[0] + y_dim = Y.dims[0] + if x_dim == VECTORIZED_DIM and y_dim == VECTORIZED_DIM: + x = X + y = Y + target_dims: tuple = (VECTORIZED_DIM,) else: # Otherwise we need to transform the full grid - x, y = xr.broadcast(ds.cf["X"], ds.cf["Y"]) + # TODO: Is there a better way to handle this? this is quite hacky + var = [d for d in ds.data_vars if x_dim and y_dim in ds[d].dims][0] + var_dims = ds[var].dims + if var_dims.index(x_dim) < var_dims.index(y_dim): + x, y = xr.broadcast(X, Y) + target_dims = (x_dim, y_dim) + else: + x, y = xr.broadcast(Y, X) + target_dims = (y_dim, x_dim) x, y = transformer.transform(x, y) @@ -111,11 +123,11 @@ def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset: # Create the new dataset with vectorized coordinates ds = ds.assign_coords( { - target_x_coord_name: ((x_dim, y_dim), x), - target_y_coord_name: ((x_dim, y_dim), y), + target_x_coord_name: (target_dims, x), + target_y_coord_name: (target_dims, y), }, ) - ds = ds.drop(coords_to_drop) + ds = ds.drop_vars(coords_to_drop) return ds diff --git a/xpublish_edr/geometry/position.py b/xpublish_edr/geometry/position.py index 671409b..f995e06 100644 --- a/xpublish_edr/geometry/position.py +++ b/xpublish_edr/geometry/position.py @@ -60,8 +60,10 @@ def _select_by_multiple_positions_regular_xy_grid( """ # Find the nearest X and Y coordinates to the point using vectorized indexing x, y = np.array(list(zip(*[(point.x, point.y) for point in points.geoms]))) - sel_x = xr.Variable(data=x, dims=VECTORIZED_DIM) - sel_y = xr.Variable(data=y, dims=VECTORIZED_DIM) + + # When using vectorized indexing with interp, we need to persist the attributes explicitly + sel_x = xr.Variable(data=x, dims=VECTORIZED_DIM, attrs=ds.cf["X"].attrs) + sel_y = xr.Variable(data=y, dims=VECTORIZED_DIM, attrs=ds.cf["Y"].attrs) if method == "nearest": return ds.cf.sel(X=sel_x, Y=sel_y, method=method) else: From 09e700018aacc566faf6c54d3b5975afc31c9924 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 14 Nov 2024 15:23:49 -0500 Subject: [PATCH 11/21] Add projection to plugin --- xpublish_edr/plugin.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/xpublish_edr/plugin.py b/xpublish_edr/plugin.py index c7866f1..fc8fd71 100644 --- a/xpublish_edr/plugin.py +++ b/xpublish_edr/plugin.py @@ -11,6 +11,7 @@ from xpublish_edr.formats.to_covjson import to_cf_covjson from xpublish_edr.geometry.area import select_by_area +from xpublish_edr.geometry.common import project_dataset from xpublish_edr.geometry.position import select_by_position from xpublish_edr.logger import logger from xpublish_edr.query import EDRQuery, edr_query @@ -112,6 +113,17 @@ def get_position( f"Dataset filtered by position ({query.geometry}): {ds}", ) + try: + ds = project_dataset(ds, query.crs) + except Exception as e: + logger.error(f"Error projecting dataset: {e}") + raise HTTPException( + status_code=404, + detail="Error projecting dataset", + ) + + logger.debug(f"Dataset projected to {query.crs}: {ds}") + if query.format: try: format_fn = position_formats()[query.format] @@ -157,6 +169,17 @@ def get_area( logger.debug(f"Dataset filtered by polygon {query.geometry.boundary}: {ds}") + try: + ds = project_dataset(ds, query.crs) + except Exception as e: + logger.error(f"Error projecting dataset: {e}") + raise HTTPException( + status_code=404, + detail="Error projecting dataset", + ) + + logger.debug(f"Dataset projected to {query.crs}: {ds}") + if query.format: try: format_fn = position_formats()[query.format] From 2fe2e98803c1db1ab5f7fe4e266b93ea8063d117 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Thu, 14 Nov 2024 15:59:23 -0500 Subject: [PATCH 12/21] Fix area subsetting, add tests for pluign --- tests/test_select.py | 33 +++++++++++++++++++++++++++++---- xpublish_edr/geometry/area.py | 14 ++++++++++---- 2 files changed, 39 insertions(+), 8 deletions(-) diff --git a/tests/test_select.py b/tests/test_select.py index 3adda6c..508bcb4 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -239,10 +239,13 @@ def test_select_position_projected_xy_multi(projected_xy_dataset): npt.assert_array_almost_equal(projected_ds.latitude, [66.6, 66.5]), "Latitude is incorrect", ) - npt.assert_array_almost_equal( - ds.temp, - projected_ds.temp, - ), "Temperature is incorrect" + ( + npt.assert_array_almost_equal( + ds.temp, + projected_ds.temp, + ), + "Temperature is incorrect", + ) def test_select_position_regular_xy_multi_interpolate(regular_xy_dataset): @@ -311,6 +314,28 @@ def test_select_area_regular_xy(regular_xy_dataset): ) +def test_select_area_projected_xy(projected_xy_dataset): + query = EDRQuery( + coords="POLYGON((64.3 66.82, 64.5 66.82, 64.5 66.6, 64.3 66.6, 64.3 66.82))", + crs="EPSG:4326", + ) + + projected_area = query.project_geometry(projected_xy_dataset) + ds = select_by_area(projected_xy_dataset, projected_area) + projected_ds = project_dataset(ds, query.crs) + + assert projected_ds is not None, "Dataset was not returned" + assert "temp" in projected_ds, "Dataset does not contain the air variable" + assert "latitude" in projected_ds, "Dataset does not contain the latitude variable" + assert ( + "longitude" in projected_ds + ), "Dataset does not contain the longitude variable" + + assert projected_ds.longitude.shape[0] == 1, "Longitude shape is incorrect" + assert projected_ds.latitude.shape[0] == 1, "Latitude shape is incorrect" + assert projected_ds.temp.shape[0] == 1, "Temperature shape is incorrect" + + def test_select_area_regular_xy_boundary(regular_xy_dataset): polygon = from_wkt("POLYGON((200 40, 200 50, 210 50, 210 40, 200 40))").buffer( 0.0001, diff --git a/xpublish_edr/geometry/area.py b/xpublish_edr/geometry/area.py index dd37d5a..7149736 100644 --- a/xpublish_edr/geometry/area.py +++ b/xpublish_edr/geometry/area.py @@ -31,7 +31,15 @@ def _select_area_regular_xy_grid( """ # To minimize performance impact, we first subset the dataset to the bounding box of the polygon (minx, miny, maxx, maxy) = polygon.bounds - ds = ds.cf.sel(X=slice(minx, maxx), Y=slice(maxy, miny)) + if ds.cf.indexes["X"].is_monotonic_increasing: + x_sel = slice(minx, maxx) + else: + x_sel = slice(maxx, minx) + if ds.cf.indexes["Y"].is_monotonic_increasing: + y_sel = slice(miny, maxy) + else: + y_sel = slice(maxy, miny) + ds = ds.cf.sel(X=x_sel, Y=y_sel) # For a regular grid, we can create a meshgrid of the X and Y coordinates to create a spatial mask pts = np.meshgrid(ds.cf["X"], ds.cf["Y"]) @@ -45,6 +53,4 @@ def _select_area_regular_xy_grid( y_sel = xr.Variable(data=y_inds, dims=VECTORIZED_DIM) # Apply the mask and vectorize to a 1d collection of points - ds_sub = ds.cf.isel(X=x_sel, Y=y_sel) - - return ds_sub + return ds.cf.isel(X=x_sel, Y=y_sel) From a852e9e214da7dd87b37e5b55ed54dc8a41cbc75 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Fri, 15 Nov 2024 14:25:51 -0500 Subject: [PATCH 13/21] Add test to plugin --- tests/test_cf_router.py | 73 ++++++++++++++++++++++++++++++++--------- 1 file changed, 57 insertions(+), 16 deletions(-) diff --git a/tests/test_cf_router.py b/tests/test_cf_router.py index 756d2cd..d903e94 100644 --- a/tests/test_cf_router.py +++ b/tests/test_cf_router.py @@ -7,15 +7,25 @@ @pytest.fixture(scope="session") -def cf_dataset(): +def cf_air_dataset(): from cf_xarray.datasets import airds return airds @pytest.fixture(scope="session") -def cf_xpublish(cf_dataset): - rest = xpublish.Rest({"air": cf_dataset}, plugins={"edr": CfEdrPlugin()}) +def cf_temp_dataset(): + from cf_xarray.datasets import rotds + + return rotds + + +@pytest.fixture(scope="session") +def cf_xpublish(cf_air_dataset, cf_temp_dataset): + rest = xpublish.Rest( + {"air": cf_air_dataset, "temp": cf_temp_dataset}, + plugins={"edr": CfEdrPlugin()}, + ) return rest @@ -52,7 +62,7 @@ def test_cf_area_formats(cf_client): assert "csv" in data, "csv is not a valid format" -def test_cf_position_query(cf_client, cf_dataset): +def test_cf_position_query(cf_client, cf_air_dataset, cf_temp_dataset): x = 204 y = 44 response = cf_client.get(f"/datasets/air/edr/position?coords=POINT({x} {y})") @@ -76,17 +86,18 @@ def test_cf_position_query(cf_client, cf_dataset): air_param = data["parameters"]["air"] assert ( - air_param["unit"]["label"]["en"] == cf_dataset["air"].attrs["units"] + air_param["unit"]["label"]["en"] == cf_air_dataset["air"].attrs["units"] ), "DataArray units should be set as parameter units" assert ( - air_param["observedProperty"]["id"] == cf_dataset["air"].attrs["standard_name"] + air_param["observedProperty"]["id"] + == cf_air_dataset["air"].attrs["standard_name"] ), "DataArray standard_name should be set as the observed property id" assert ( air_param["observedProperty"]["label"]["en"] - == cf_dataset["air"].attrs["long_name"] + == cf_air_dataset["air"].attrs["long_name"] ), "DataArray long_name should be set as parameter observed property" assert ( - air_param["description"]["en"] == cf_dataset["air"].attrs["long_name"] + air_param["description"]["en"] == cf_air_dataset["air"].attrs["long_name"] ), "DataArray long_name should be set as parameter description" air_range = data["ranges"]["air"] @@ -99,6 +110,35 @@ def test_cf_position_query(cf_client, cf_dataset): len(air_range["values"]) == 4 ), "There should be 4 values, one for each time step" + # Test with a dataset containing data in a different coordinate system + x = 64.59063409 + y = 66.66454929 + response = cf_client.get(f"/datasets/temp/edr/position?coords=POINT({x} {y})") + assert response.status_code == 200, "Response did not return successfully" + + data = response.json() + print(data) + for key in ("type", "domain", "parameters", "ranges"): + assert key in data, f"Key {key} is not a top level key in the CovJSON response" + + axes = data["domain"]["axes"] + + npt.assert_array_almost_equal( + axes["longitude"]["values"], + [[64.59063409]], + ), "Did not select nearby x coordinate" + npt.assert_array_almost_equal( + axes["latitude"]["values"], + [[66.66454929]], + ), "Did not select a nearby y coordinate" + + temp_range = data["ranges"]["temp"] + assert temp_range["type"] == "NdArray", "Response range should be a NdArray" + assert temp_range["dataType"] == "float", "Air dataType should be floats" + assert temp_range["axisNames"] == ["rlat", "rlon"], "All dimensions should persist" + assert temp_range["shape"] == [1, 1], "The shape of the array should be 1x1" + assert len(temp_range["values"]) == 1, "There should be 1 value selected" + def test_cf_position_csv(cf_client): x = 204 @@ -346,7 +386,7 @@ def test_cf_multiple_position_csv(cf_client): assert key in csv_data[0], f"column {key} should be in the header" -def test_cf_area_query(cf_client, cf_dataset): +def test_cf_area_query(cf_client, cf_air_dataset): coords = "POLYGON((201 41, 201 49, 209 49, 209 41, 201 41))" response = cf_client.get(f"/datasets/air/edr/area?coords={coords}&f=cf_covjson") @@ -372,17 +412,18 @@ def test_cf_area_query(cf_client, cf_dataset): air_param = data["parameters"]["air"] assert ( - air_param["unit"]["label"]["en"] == cf_dataset["air"].attrs["units"] + air_param["unit"]["label"]["en"] == cf_air_dataset["air"].attrs["units"] ), "DataArray units should be set as parameter units" assert ( - air_param["observedProperty"]["id"] == cf_dataset["air"].attrs["standard_name"] + air_param["observedProperty"]["id"] + == cf_air_dataset["air"].attrs["standard_name"] ), "DataArray standard_name should be set as the observed property id" assert ( air_param["observedProperty"]["label"]["en"] - == cf_dataset["air"].attrs["long_name"] + == cf_air_dataset["air"].attrs["long_name"] ), "DataArray long_name should be set as parameter observed property" assert ( - air_param["description"]["en"] == cf_dataset["air"].attrs["long_name"] + air_param["description"]["en"] == cf_air_dataset["air"].attrs["long_name"] ), "DataArray long_name should be set as parameter description" air_range = data["ranges"]["air"] @@ -403,7 +444,7 @@ def test_cf_area_query(cf_client, cf_dataset): ), "There should be 26 values, 9 for each time step" -def test_cf_area_csv_query(cf_client, cf_dataset): +def test_cf_area_csv_query(cf_client, cf_air_dataset): coords = "POLYGON((201 41, 201 49, 209 49, 209 41, 201 41))" response = cf_client.get(f"/datasets/air/edr/area?coords={coords}&f=csv") @@ -427,7 +468,7 @@ def test_cf_area_csv_query(cf_client, cf_dataset): assert key in csv_data[0], f"column {key} should be in the header" -def test_cf_area_geojson_query(cf_client, cf_dataset): +def test_cf_area_geojson_query(cf_client, cf_air_dataset): coords = "POLYGON((201 41, 201 49, 209 49, 209 41, 201 41))" response = cf_client.get(f"/datasets/air/edr/area?coords={coords}&f=geojson") @@ -446,7 +487,7 @@ def test_cf_area_geojson_query(cf_client, cf_dataset): assert len(features) == 36, "There should be 36 data points" -def test_cf_area_nc_query(cf_client, cf_dataset): +def test_cf_area_nc_query(cf_client, cf_air_dataset): coords = "POLYGON((201 41, 201 49, 209 49, 209 41, 201 41))" response = cf_client.get(f"/datasets/air/edr/area?coords={coords}&f=nc") From afd7d453331c35dc6af54420d081091c605ea952 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Fri, 15 Nov 2024 16:31:06 -0500 Subject: [PATCH 14/21] Remove print Co-authored-by: Deepak Cherian --- tests/test_cf_router.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_cf_router.py b/tests/test_cf_router.py index d903e94..6afc671 100644 --- a/tests/test_cf_router.py +++ b/tests/test_cf_router.py @@ -117,7 +117,6 @@ def test_cf_position_query(cf_client, cf_air_dataset, cf_temp_dataset): assert response.status_code == 200, "Response did not return successfully" data = response.json() - print(data) for key in ("type", "domain", "parameters", "ranges"): assert key in data, f"Key {key} is not a top level key in the CovJSON response" From 4c70243c7fb221c88526c455489e7a5dbc13bb24 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Fri, 15 Nov 2024 16:31:32 -0500 Subject: [PATCH 15/21] Update tests/test_cf_router.py Co-authored-by: Deepak Cherian --- tests/test_cf_router.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_cf_router.py b/tests/test_cf_router.py index 6afc671..a7a8041 100644 --- a/tests/test_cf_router.py +++ b/tests/test_cf_router.py @@ -124,11 +124,11 @@ def test_cf_position_query(cf_client, cf_air_dataset, cf_temp_dataset): npt.assert_array_almost_equal( axes["longitude"]["values"], - [[64.59063409]], + [[x]], ), "Did not select nearby x coordinate" npt.assert_array_almost_equal( axes["latitude"]["values"], - [[66.66454929]], + [[y]], ), "Did not select a nearby y coordinate" temp_range = data["ranges"]["temp"] From 526472bf986a987444c1960df7de4934e715afaa Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Fri, 15 Nov 2024 16:31:50 -0500 Subject: [PATCH 16/21] Update tests/test_select.py Co-authored-by: Deepak Cherian --- tests/test_select.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_select.py b/tests/test_select.py index 508bcb4..e54f41d 100644 --- a/tests/test_select.py +++ b/tests/test_select.py @@ -155,7 +155,7 @@ def test_select_position_projected_xy(projected_xy_dataset): npt.assert_approx_equal(projected_point.y, 21.725), "Latitude is incorrect" ds = select_by_position(projected_xy_dataset, projected_point) - xrt.assert_equal( + xrt.assert_identical( ds, projected_xy_dataset.sel(rlon=[18.045], rlat=[21.725], method="nearest"), ) From 76df163f0a914141ceaf34841ed536c6687904ba Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Fri, 15 Nov 2024 16:32:17 -0500 Subject: [PATCH 17/21] Update xpublish_edr/geometry/area.py Co-authored-by: Deepak Cherian --- xpublish_edr/geometry/area.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/xpublish_edr/geometry/area.py b/xpublish_edr/geometry/area.py index 7149736..d2a0a78 100644 --- a/xpublish_edr/geometry/area.py +++ b/xpublish_edr/geometry/area.py @@ -31,11 +31,12 @@ def _select_area_regular_xy_grid( """ # To minimize performance impact, we first subset the dataset to the bounding box of the polygon (minx, miny, maxx, maxy) = polygon.bounds - if ds.cf.indexes["X"].is_monotonic_increasing: + indexes = ds.cf.indexes + if indexes["X"].is_monotonic_increasing: x_sel = slice(minx, maxx) else: x_sel = slice(maxx, minx) - if ds.cf.indexes["Y"].is_monotonic_increasing: + if indexes["Y"].is_monotonic_increasing: y_sel = slice(miny, maxy) else: y_sel = slice(maxy, miny) From dffeffe5f4a3e1deac9c14cc2fb2e1b6baea3200 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Fri, 15 Nov 2024 16:43:19 -0500 Subject: [PATCH 18/21] Fix coord drops --- xpublish_edr/geometry/common.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/xpublish_edr/geometry/common.py b/xpublish_edr/geometry/common.py index 3961b05..d10e027 100644 --- a/xpublish_edr/geometry/common.py +++ b/xpublish_edr/geometry/common.py @@ -115,10 +115,14 @@ def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset: target_x_coord_name = target_x_coord["standard_name"] target_y_coord_name = target_y_coord["standard_name"] - if target_x_coord_name in ds: - target_x_coord_name += "_" - if target_y_coord_name in ds: - target_y_coord_name += "_" + stdnames = ds.cf.standard_names + coords_to_drop += list( + itertools.chain( + stdnames.get(target_x_coord_name, []), + stdnames.get(target_y_coord_name, []), + ), + ) + ds = ds.drop_vars(coords_to_drop) # Create the new dataset with vectorized coordinates ds = ds.assign_coords( @@ -128,6 +132,4 @@ def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset: }, ) - ds = ds.drop_vars(coords_to_drop) - return ds From ee1c4debeef6fcdd8e63806f1e7121b2ef4527e8 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Fri, 15 Nov 2024 16:45:27 -0500 Subject: [PATCH 19/21] Require lat and lng for projecting --- xpublish_edr/geometry/common.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/xpublish_edr/geometry/common.py b/xpublish_edr/geometry/common.py index d10e027..a57ce10 100644 --- a/xpublish_edr/geometry/common.py +++ b/xpublish_edr/geometry/common.py @@ -34,7 +34,11 @@ def dataset_crs(ds: xr.Dataset) -> pyproj.CRS: grid_mapping_names = ds.cf.grid_mapping_names if len(grid_mapping_names) == 0: # Default to WGS84 - return pyproj.crs.CRS.from_epsg(4326) + keys = ds.cf.keys() + if "latitude" in keys and "longitude" in keys: + return pyproj.CRS.from_epsg(4326) + else: + raise ValueError("Unknown coordinate system") if len(grid_mapping_names) > 1: raise ValueError(f"Multiple grid mappings found: {grid_mapping_names!r}!") (grid_mapping_var,) = tuple(itertools.chain(*ds.cf.grid_mapping_names.values())) From a5a7e8880a508cc51108b0eecaca0dcc5de5c48f Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Fri, 15 Nov 2024 16:50:38 -0500 Subject: [PATCH 20/21] Cleaner broadcast op --- xpublish_edr/geometry/common.py | 22 +++++----------------- 1 file changed, 5 insertions(+), 17 deletions(-) diff --git a/xpublish_edr/geometry/common.py b/xpublish_edr/geometry/common.py index a57ce10..24be958 100644 --- a/xpublish_edr/geometry/common.py +++ b/xpublish_edr/geometry/common.py @@ -92,26 +92,14 @@ def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset: if len(X.dims) > 1 or len(Y.dims) > 1: raise NotImplementedError("Only 1D coordinates are supported") - x_dim = X.dims[0] - y_dim = Y.dims[0] - if x_dim == VECTORIZED_DIM and y_dim == VECTORIZED_DIM: - x = X - y = Y - target_dims: tuple = (VECTORIZED_DIM,) - else: - # Otherwise we need to transform the full grid - # TODO: Is there a better way to handle this? this is quite hacky - var = [d for d in ds.data_vars if x_dim and y_dim in ds[d].dims][0] - var_dims = ds[var].dims - if var_dims.index(x_dim) < var_dims.index(y_dim): - x, y = xr.broadcast(X, Y) - target_dims = (x_dim, y_dim) - else: - x, y = xr.broadcast(Y, X) - target_dims = (y_dim, x_dim) + x, y = xr.broadcast(X, Y) + target_dims = x.dims x, y = transformer.transform(x, y) + x_dim = X.dims[0] + y_dim = Y.dims[0] + coords_to_drop = [ c for c in ds.coords if x_dim in ds[c].dims or y_dim in ds[c].dims ] From 501417bdbb7d94d3f25b1f9134bbeeb1ff0497b8 Mon Sep 17 00:00:00 2001 From: Matthew Iannucci Date: Fri, 15 Nov 2024 17:03:10 -0500 Subject: [PATCH 21/21] Explicit transpose to cleanup dataset on projection --- xpublish_edr/geometry/common.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/xpublish_edr/geometry/common.py b/xpublish_edr/geometry/common.py index 24be958..16368c9 100644 --- a/xpublish_edr/geometry/common.py +++ b/xpublish_edr/geometry/common.py @@ -124,4 +124,7 @@ def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset: }, ) + if x_dim != y_dim: + ds = ds.transpose(..., y_dim, x_dim) + return ds