Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add CRS support for queries #58

Merged
merged 22 commits into from
Nov 15, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ In the future, when `xpublish` supports [`DataTree`](https://docs.xarray.dev/en/
| `z` | ✅ | |
| `datetime` | ✅ | |
| `parameter-name` | ✅ | |
| `crs` | ❌ | Not currently supported, all coordinates should be in the reference system of the queried dataset |
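| `crs` | ✅ | CRS of the query coordinates; defaults to `EPSG:4326`. The CRS of the target dataset is read from its CF-compliant [grid mapping](https://cf-xarray.readthedocs.io/en/latest/grid_mappings.html); datasets without a grid mapping are assumed to be `EPSG:4326` |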
| `parameter-name` | ✅ | |
| `f` | ✅ | |
| `method` | ➕ | Optional: controls data selection. Use "nearest" for nearest neighbor selection, or "linear" for interpolated selection. Uses `nearest` if not specified |
Expand All @@ -84,7 +84,7 @@ In the future, when `xpublish` supports [`DataTree`](https://docs.xarray.dev/en/
| `z` | ✅ | |
| `datetime` | ✅ | |
| `parameter-name` | ✅ | |
| `crs` | ❌ | Not currently supported, all coordinates should be in the reference system of the queried dataset |
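| `crs` | ✅ | CRS of the query coordinates; defaults to `EPSG:4326`. The CRS of the target dataset is read from its CF-compliant [grid mapping](https://cf-xarray.readthedocs.io/en/latest/grid_mappings.html); datasets without a grid mapping are assumed to be `EPSG:4326` |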
| `parameter-name` | ✅ | |
| `f` | ✅ | |
| `method` | ➕ | Optional: controls data selection. Use "nearest" for nearest neighbor selection, or "linear" for interpolated selection. Uses `nearest` if not specified |
Expand Down
75 changes: 75 additions & 0 deletions tests/test_select.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,11 @@
import pandas as pd
import pytest
import xarray as xr
import xarray.testing as xrt
from shapely import MultiPoint, Point, from_wkt

from xpublish_edr.geometry.area import select_by_area
from xpublish_edr.geometry.common import project_dataset
from xpublish_edr.geometry.position import select_by_position
from xpublish_edr.query import EDRQuery

Expand All @@ -17,6 +19,14 @@ def regular_xy_dataset():
return xr.tutorial.load_dataset("air_temperature")


@pytest.fixture(scope="function")
def projected_xy_dataset():
    """Provide cf_xarray's ``rotds`` sample dataset, indexed by ``rlon``/``rlat``"""
    # Function-local import: cf_xarray.datasets is only imported once a test
    # actually requests this fixture.
    import cf_xarray.datasets as sample_datasets

    return sample_datasets.rotds


def test_select_query(regular_xy_dataset):
query = EDRQuery(
coords="POINT(200 45)",
Expand Down Expand Up @@ -134,6 +144,44 @@ def test_select_position_regular_xy(regular_xy_dataset):
npt.assert_approx_equal(ds["air"][-1], 279.19), "Temperature is incorrect"


def test_select_position_projected_xy(projected_xy_dataset):
    """Round-trip a single EPSG:4326 point through the projected dataset.

    The query point is projected into the dataset's CRS, used to select the
    nearest grid cell, and the selection is projected back to the query CRS.
    """
    query = EDRQuery(
        coords="POINT(64.59063409 66.66454929)",
        crs="EPSG:4326",
    )

    # Failure messages go through ``err_msg``; a trailing ``, "msg"`` after the
    # call only builds a throwaway tuple and the message is never reported.
    projected_point = query.project_geometry(projected_xy_dataset)
    npt.assert_approx_equal(
        projected_point.x,
        18.045,
        err_msg="Longitude is incorrect",
    )
    npt.assert_approx_equal(
        projected_point.y,
        21.725,
        err_msg="Latitude is incorrect",
    )

    # The grid cell the projected point is expected to resolve to
    expected = projected_xy_dataset.sel(
        rlon=[18.045],
        rlat=[21.725],
        method="nearest",
    )

    ds = select_by_position(projected_xy_dataset, projected_point)
    xrt.assert_equal(ds, expected)

    projected_ds = project_dataset(ds, query.crs)
    npt.assert_approx_equal(
        projected_ds.longitude.values,
        64.59063409,
        err_msg="Longitude is incorrect",
    )
    npt.assert_approx_equal(
        projected_ds.latitude.values,
        66.66454929,
        err_msg="Latitude is incorrect",
    )
    npt.assert_approx_equal(
        projected_ds.temp.values,
        expected.temp.values,
        err_msg="Temperature is incorrect",
    )


def test_select_position_regular_xy_interpolate(regular_xy_dataset):
point = Point((204, 44))
ds = select_by_position(regular_xy_dataset, point, method="linear")
Expand Down Expand Up @@ -170,6 +218,33 @@ def test_select_position_regular_xy_multi(regular_xy_dataset):
)


def test_select_position_projected_xy_multi(projected_xy_dataset):
    """Select several EPSG:4326 points from the projected dataset.

    After projecting the selection back to the query CRS, the dataset should
    expose longitude/latitude at the queried points instead of the original
    ``rlon``/``rlat`` coordinates, with the data values left untouched.
    """
    query = EDRQuery(
        coords="MULTIPOINT(64.3 66.6, 64.6 66.5)",
        crs="EPSG:4326",
        method="linear",
    )

    projected_points = query.project_geometry(projected_xy_dataset)
    ds = select_by_position(projected_xy_dataset, projected_points, method="linear")
    projected_ds = project_dataset(ds, query.crs)

    assert "temp" in projected_ds, "Dataset does not contain the temp variable"
    # These fire when the original coordinates were NOT dropped, so the
    # messages must say the variable is still present.
    assert "rlon" not in projected_ds, "Dataset still contains the rlon variable"
    assert "rlat" not in projected_ds, "Dataset still contains the rlat variable"

    # Failure messages go through ``err_msg``; a trailing ``, "msg"`` after the
    # call only builds a throwaway tuple and the message is never reported.
    npt.assert_array_almost_equal(
        projected_ds.longitude,
        [64.3, 64.6],
        err_msg="Longitude is incorrect",
    )
    npt.assert_array_almost_equal(
        projected_ds.latitude,
        [66.6, 66.5],
        err_msg="Latitude is incorrect",
    )
    npt.assert_array_almost_equal(
        ds.temp,
        projected_ds.temp,
        err_msg="Temperature is incorrect",
    )


def test_select_position_regular_xy_multi_interpolate(regular_xy_dataset):
points = MultiPoint([(202, 45), (205, 48)])
ds = select_by_position(regular_xy_dataset, points, method="linear")
Expand Down
112 changes: 112 additions & 0 deletions xpublish_edr/geometry/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,19 @@
Common geometry handling functions
"""

import itertools
from functools import lru_cache

import pyproj
import xarray as xr
from shapely import Geometry
from shapely.ops import transform

# Name of the dimension used for vectorized (point-by-point) coordinate selection
VECTORIZED_DIM = "pts"

# Memoized Transformer factory, so repeated requests with the same arguments
# reuse one Transformer. See the pyproj docs on caching pyproj objects:
# https://pyproj4.github.io/pyproj/stable/advanced_examples.html#caching-pyproj-objects
transformer_from_crs = lru_cache(pyproj.Transformer.from_crs)


def coord_is_regular(da: xr.DataArray) -> bool:
"""
Expand All @@ -19,3 +28,106 @@ def is_regular_xy_coords(ds: xr.Dataset) -> bool:
Check if the dataset has 2D coordinates
"""
return coord_is_regular(ds.cf["X"]) and coord_is_regular(ds.cf["Y"])


def dataset_crs(ds: xr.Dataset) -> pyproj.CRS:
    """
    Determine the coordinate reference system of a dataset

    The CRS is built from the attributes of the dataset's CF grid mapping
    variable. A dataset that declares no grid mapping is assumed to be in
    WGS84 (EPSG:4326).

    Args:
        ds: Dataset to inspect through its ``cf`` accessor.

    Returns:
        The CRS described by the grid mapping variable, or EPSG:4326 when
        the dataset has no grid mapping.

    Raises:
        ValueError: If more than one grid mapping name is found, or if the
            single grid mapping name does not resolve to exactly one variable.
    """
    grid_mapping_names = ds.cf.grid_mapping_names
    if not grid_mapping_names:
        # Default to WGS84
        return pyproj.crs.CRS.from_epsg(4326)
    if len(grid_mapping_names) > 1:
        raise ValueError(f"Multiple grid mappings found: {grid_mapping_names!r}!")

    # One grid mapping name can still map to several variables; report that
    # explicitly instead of failing on a bare tuple-unpacking error.
    grid_mapping_vars = list(
        itertools.chain.from_iterable(grid_mapping_names.values()),
    )
    if len(grid_mapping_vars) != 1:
        raise ValueError(
            "Expected exactly one grid mapping variable, "
            f"found: {grid_mapping_vars!r}",
        )

    grid_mapping = ds[grid_mapping_vars[0]]
    return pyproj.CRS.from_cf(grid_mapping.attrs)


def project_geometry(ds: xr.Dataset, geometry_crs: str, geometry: Geometry) -> Geometry:
    """
    Project a geometry from its own CRS into the CRS of the dataset

    Args:
        ds: Dataset whose CRS (as resolved by ``dataset_crs``) is the target.
        geometry_crs: CRS the coordinates of ``geometry`` are expressed in.
        geometry: Shapely geometry to project.

    Returns:
        The geometry with its coordinates transformed into the dataset's CRS.
    """
    # always_xy=True: coordinates are passed as (x, y) regardless of the axis
    # order either CRS declares.
    to_dataset_crs = transformer_from_crs(
        crs_from=geometry_crs,
        crs_to=dataset_crs(ds),
        always_xy=True,
    )
    return transform(to_dataset_crs.transform, geometry)


def project_dataset(ds: xr.Dataset, query_crs: str) -> xr.Dataset:
    """
    Project the dataset's horizontal coordinates to the given CRS

    The CF ``X``/``Y`` coordinates are transformed from the dataset's CRS (as
    resolved by ``dataset_crs``) into ``query_crs`` and attached as new
    coordinates named after the CF ``standard_name`` of the target CRS axes.
    Every existing coordinate that depends on the original X or Y dimension
    is dropped from the result.

    Args:
        ds: Dataset with 1D CF ``X`` and ``Y`` coordinates.
        query_crs: Target CRS, in a form ``pyproj.CRS.from_string`` accepts.

    Returns:
        ``ds`` itself when it is already in ``query_crs``, otherwise a dataset
        carrying the projected coordinates.

    Raises:
        NotImplementedError: If the X or Y coordinate has more than one
            dimension.
    """
    data_crs = dataset_crs(ds)
    target_crs = pyproj.CRS.from_string(query_crs)
    if data_crs == target_crs:
        return ds

    transformer = transformer_from_crs(
        crs_from=data_crs,
        crs_to=target_crs,
        always_xy=True,
    )

    # TODO: Handle rotated pole. This only matters when the *target* CRS is a
    # rotated coordinate system, which is ignored for now.
    cf_coords = target_crs.coordinate_system.to_cf()

    # CF descriptions of the target X and Y axes; they name the new coordinates
    target_x_coord = next(coord for coord in cf_coords if coord["axis"] == "X")
    target_y_coord = next(coord for coord in cf_coords if coord["axis"] == "Y")

    X = ds.cf["X"]
    Y = ds.cf["Y"]

    # TODO: Handle 2D coordinates
    if len(X.dims) > 1 or len(Y.dims) > 1:
        raise NotImplementedError("Only 1D coordinates are supported")

    x_dim = X.dims[0]
    y_dim = Y.dims[0]
    if x_dim == VECTORIZED_DIM and y_dim == VECTORIZED_DIM:
        # Vectorized selection: X and Y already pair up point by point, so the
        # points are transformed as they are
        x, y = X, Y
        target_dims: tuple = (VECTORIZED_DIM,)
    else:
        # Otherwise the full grid is transformed. The new 2D coordinates follow
        # the dimension order of the first data variable spanning both dims.
        # TODO: Is there a better way to handle this? this is quite hacky
        grid_var_dims = next(
            (
                ds[name].dims
                for name in ds.data_vars
                # Both memberships must be tested explicitly:
                # ``x_dim and y_dim in dims`` only checks ``y_dim``.
                if x_dim in ds[name].dims and y_dim in ds[name].dims
            ),
            None,
        )
        if grid_var_dims is not None and (
            grid_var_dims.index(x_dim) < grid_var_dims.index(y_dim)
        ):
            x, y = xr.broadcast(X, Y)
            target_dims = (x_dim, y_dim)
        else:
            # Also used when no data variable spans both dims: the order is
            # then arbitrary and the coordinates are valid either way.
            # xr.broadcast returns its results in argument order, so the
            # broadcast Y comes first here.
            y, x = xr.broadcast(Y, X)
            target_dims = (y_dim, x_dim)

    x, y = transformer.transform(x, y)

    # Collected before the new coordinates are attached: they share the X/Y
    # dimensions and would otherwise be dropped again below.
    coords_to_drop = [
        c for c in ds.coords if x_dim in ds[c].dims or y_dim in ds[c].dims
    ]

    target_x_coord_name = target_x_coord["standard_name"]
    target_y_coord_name = target_y_coord["standard_name"]

    # Do not overwrite an existing variable that already uses the target name
    if target_x_coord_name in ds:
        target_x_coord_name += "_"
    if target_y_coord_name in ds:
        target_y_coord_name += "_"

    # Attach the projected coordinates, then remove the original ones
    projected = ds.assign_coords(
        {
            target_x_coord_name: (target_dims, x),
            target_y_coord_name: (target_dims, y),
        },
    )
    return projected.drop_vars(coords_to_drop)
6 changes: 4 additions & 2 deletions xpublish_edr/geometry/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,10 @@ def _select_by_multiple_positions_regular_xy_grid(
"""
# Find the nearest X and Y coordinates to the point using vectorized indexing
x, y = np.array(list(zip(*[(point.x, point.y) for point in points.geoms])))
sel_x = xr.Variable(data=x, dims=VECTORIZED_DIM)
sel_y = xr.Variable(data=y, dims=VECTORIZED_DIM)

# When using vectorized indexing with interp, we need to persist the attributes explicitly
sel_x = xr.Variable(data=x, dims=VECTORIZED_DIM, attrs=ds.cf["X"].attrs)
sel_y = xr.Variable(data=y, dims=VECTORIZED_DIM, attrs=ds.cf["Y"].attrs)
if method == "nearest":
return ds.cf.sel(X=sel_x, Y=sel_y, method=method)
else:
Expand Down
4 changes: 2 additions & 2 deletions xpublish_edr/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def get_position(
logger.debug(f"Dataset filtered by query params {ds}")

try:
ds = select_by_position(ds, query.geometry, query.method)
ds = select_by_position(ds, query.project_geometry(ds), query.method)
except KeyError:
raise HTTPException(
status_code=404,
Expand Down Expand Up @@ -148,7 +148,7 @@ def get_area(
logger.debug(f"Dataset filtered by query params {ds}")

try:
ds = select_by_area(ds, query.geometry)
ds = select_by_area(ds, query.project_geometry(ds))
except KeyError:
raise HTTPException(
status_code=404,
Expand Down
20 changes: 15 additions & 5 deletions xpublish_edr/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import xarray as xr
from fastapi import Query
from pydantic import BaseModel, Field
from shapely import wkt
from shapely import Geometry, wkt

from xpublish_edr.geometry.common import project_geometry
from xpublish_edr.logger import logger


Expand All @@ -25,15 +26,24 @@ class EDRQuery(BaseModel):
z: Optional[str] = None
datetime: Optional[str] = None
parameters: Optional[str] = None
crs: Optional[str] = None
crs: str = Field(
"EPSG:4326",
title="Coordinate Reference System",
description="Coordinate Reference System for the query. Default is EPSG:4326",
)
format: Optional[str] = None
method: Literal["nearest", "linear"] = "nearest"

@property
def geometry(self):
def geometry(self) -> Geometry:
"""Shapely point from WKT query params"""
return wkt.loads(self.coords)

def project_geometry(self, ds: xr.Dataset) -> Geometry:
    """Return the query geometry projected from ``self.crs`` into the CRS of ``ds``"""
    # Delegates to the module-level ``project_geometry`` helper imported from
    # xpublish_edr.geometry.common, not to this method.
    return project_geometry(ds, self.crs, self.geometry)

def select(self, ds: xr.Dataset, query_params: dict) -> xr.Dataset:
"""Select data from a dataset based on the query"""
if self.z:
Expand Down Expand Up @@ -118,8 +128,8 @@ def edr_query(
alias="parameter-name",
description="xarray variables to query",
),
crs: Optional[str] = Query(
None,
crs: str = Query(
"EPSG:4326",
deprecated=True,
description="CRS is not yet implemented",
),
Expand Down
Loading