From 5ec5c46fd02365ff9170313ed6810bd3e053eadd Mon Sep 17 00:00:00 2001 From: Nicholas Delli Carpini Date: Tue, 7 Nov 2023 12:44:23 -0500 Subject: [PATCH] fix fvcom grid, update grid.select_by_elevation, and fix bad float parsing --- requirements.txt | 1 + xpublish_wms/grid.py | 120 +++++++++++++++++++++------ xpublish_wms/utils.py | 8 ++ xpublish_wms/wms/get_feature_info.py | 11 +-- xpublish_wms/wms/get_legend_info.py | 4 +- xpublish_wms/wms/get_map.py | 6 +- 6 files changed, 111 insertions(+), 39 deletions(-) diff --git a/requirements.txt b/requirements.txt index f7b7ba2..25d7760 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,6 +4,7 @@ cf_xarray datashader matplotlib Pillow +scikit-learn rioxarray>=0.12.2 xarray xpublish>=0.3.0,<0.4.0 diff --git a/xpublish_wms/grid.py b/xpublish_wms/grid.py index 836dd97..be26943 100644 --- a/xpublish_wms/grid.py +++ b/xpublish_wms/grid.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod from enum import Enum -from typing import Any, Optional, Tuple, Union +from typing import Any, Optional, Tuple, Union, Sequence import cartopy.geodesic import cf_xarray # noqa @@ -11,6 +11,7 @@ import xarray as xr from xpublish_wms.utils import strip_float, to_mercator +from sklearn.neighbors import BallTree class RenderMethod(Enum): @@ -84,11 +85,21 @@ def elevations(self, da: xr.DataArray) -> Optional[xr.DataArray]: def select_by_elevation( self, da: xr.DataArray, - elevation: float = 0.0, + elevations: Sequence[float], ) -> xr.DataArray: """Select the given data array by elevation""" + + if elevations is None or len(elevations) == 0 or all(v is None for v in elevations): + elevations = [0.0] + if "vertical" in da.cf: - da = da.cf.sel({"vertical": elevation}, method="nearest") + if len(elevations) == 1: + return da.cf.sel(vertical=elevations[0], method="nearest") + elif len(elevations) > 1: + return da.cf.sel(vertical=elevations) + else: + return da.cf.sel(vertical=0, method="nearest") + return da def mask( @@ -365,46 +376,101 @@ def elevation_positive_direction(self, da: xr.DataArray) -> Optional[str]: def elevations(self, da: xr.DataArray) -> Optional[xr.DataArray]: if "vertical" in da.cf: return da.cf["vertical"][:, 0] - elif "siglay" in da.dims: - # Sometimes fvcom variables dont have coordinates assigned correctly, so brute force it - return self.ds.siglay[:, 0] - elif "siglev" in da.dims: + else: # Sometimes fvcom variables dont have coordinates assigned correctly, so brute force it - return self.ds.siglev[:, 0] + vertical_var = None + if "siglay" in da.dims: + vertical_var = "siglay" + elif "siglev" in da.dims: + vertical_var = "siglev" + + if vertical_var is not None: + temp_elevations = self.ds[vertical_var].values[:, 0] + return xr.DataArray( + data=[temp_elevations[i] for i in da[vertical_var]], + dims=da[vertical_var].dims, + coords=da[vertical_var].coords, + name=self.ds[vertical_var].name, + attrs=self.ds[vertical_var].attrs + ) + + return None + + def sel_lat_lng( + self, + subset: xr.Dataset, + lng, + lat, + parameters, + ) -> Tuple[xr.Dataset, list, list]: + """Select the given dataset by the given lon/lat and optional elevation""" + + lng_rad = np.deg2rad(subset.cf["longitude"]) + lat_rad = np.deg2rad(subset.cf["latitude"]) + + stacked = np.stack([lng_rad, lat_rad], axis=-1) + tree = BallTree(stacked, leaf_size=2, metric="haversine") + + idx = tree.query([[np.deg2rad((360 + lng) if lng < 0 else lng), np.deg2rad(lat)]], return_distance=False) + + if 'nele' in subset.dims: + subset = subset.isel(nele=idx[0][0]) else: - return None + subset = subset.isel(node=idx[0][0]) + + x_axis = [strip_float(subset.cf["longitude"])] + y_axis = [strip_float(subset.cf["latitude"])] + return subset, x_axis, y_axis def select_by_elevation( self, da: xr.DataArray, - elevation: Optional[float], + elevations: Optional[Sequence[float]], ) -> xr.DataArray: """Select the given data array by elevation""" - print(da.coords) - print(da.cf) if not self.has_elevation(da): return da - if elevation is None: - elevation = 0.0 + if elevations is None or len(elevations) == 0 or all(v is None for v in elevations): + elevations = [0.0] + + da_elevations = self.elevations(da) + + elevation_index = [int(np.absolute(da_elevations - elevation).argmin().values) for elevation in elevations] + if len(elevation_index) == 1: + elevation_index = elevation_index[0] + + if "vertical" not in da.cf: + if "siglay" in da.dims: + da.__setitem__("siglay", da_elevations) + elif "siglev" in da.dims: + da.__setitem__("siglev", da_elevations) - elevations = self.elevations(da) - diff = np.absolute(elevations - elevation) - elevation_index = int(diff.argmin().values) if "vertical" in da.cf: da = da.cf.isel(vertical=elevation_index) - elif "siglay" in da.dims: - print(elevation_index) - da = da.isel(siglay=elevation_index) - elif "siglev" in da.dims: - da = da.isel(siglev=elevation_index) - - print(da.coords) - print(da.cf) return da def project(self, da: xr.DataArray, crs: str) -> Any: + # fvcom nodal variables have data on both the faces and edges + if 'nele' in da.dims: + elem_count = self.ds.ntve.isel(time=0).values + neighbors = self.ds.nbve.isel(time=0).values + mask = (neighbors[:, :] > 0) + + new_values = np.sum(da.values[neighbors[:, :] - 1], axis=0, where=mask) / elem_count + da = xr.DataArray( + data=new_values, + dims=da.dims, + name=da.name, + attrs=da.attrs, + coords=dict( + lonc=(da.cf["longitude"].dims, self.ds.lon.values, da.cf["longitude"].attrs), + latc=(da.cf["latitude"].dims, self.ds.lat.values, da.cf["latitude"].attrs), + time=da.coords['time'] + ), + ) + if crs == "EPSG:4326": da = da.assign_coords({"x": da.cf["longitude"], "y": da.cf["latitude"]}) elif crs == "EPSG:3857": @@ -528,12 +594,12 @@ def elevations(self, da: xr.DataArray) -> Optional[xr.DataArray]: def select_by_elevation( self, da: xr.DataArray, - elevation: Optional[float], + elevations: Optional[Sequence[float]], ) -> xr.DataArray: if self._grid is None: return None else: - return self._grid.select_by_elevation(da, elevation) + return self._grid.select_by_elevation(da, elevations) def mask( self, diff --git a/xpublish_wms/utils.py b/xpublish_wms/utils.py index 72fbc6c..bebad70 100644 --- a/xpublish_wms/utils.py +++ b/xpublish_wms/utils.py @@ -20,6 +20,14 @@ def strip_float(value): return float(value.values) +def parse_float(value): + if 'e' in value.lower(): + part_arr = value.lower().split("e") + return float(part_arr[0].strip()) * (10 ** float(part_arr[1].strip())) + + return float(value.strip()) + + def round_float_values(v) -> list: if not isinstance(v, list): return round(v, 5) diff --git a/xpublish_wms/wms/get_feature_info.py b/xpublish_wms/wms/get_feature_info.py index f6d071f..1cf0d70 100644 --- a/xpublish_wms/wms/get_feature_info.py +++ b/xpublish_wms/wms/get_feature_info.py @@ -159,14 +159,9 @@ def get_feature_info(ds: xr.Dataset, query: dict) -> Response: if any_has_vertical_axis: if elevation == "all": # Dont select an elevation, just keep all elevation coords - elevation = selected_ds.cf["vertical"].values - elif len(elevation) == 1: - selected_ds = selected_ds.cf.sel(vertical=elevation, method="nearest") - elif len(elevation) > 1: - selected_ds = selected_ds.cf.sel(vertical=slice(elevation[0], elevation[1])) - else: - # Select closest to the surface by default - selected_ds = selected_ds.cf.sel(vertical=0, method="nearest") + elevation = ds.gridded.elevations(selected_ds) + + selected_ds = ds.gridded.select_by_elevation(selected_ds, elevation) try: # Apply masking if necessary diff --git a/xpublish_wms/wms/get_legend_info.py b/xpublish_wms/wms/get_legend_info.py index 5ef07aa..f2c5c55 100644 --- a/xpublish_wms/wms/get_legend_info.py +++ b/xpublish_wms/wms/get_legend_info.py @@ -7,6 +7,8 @@ from matplotlib import cm from PIL import Image +from xpublish_wms.utils import parse_float + def get_legend_info(dataset: xr.Dataset, query: dict) -> Response: """ @@ -18,7 +20,7 @@ def get_legend_info(dataset: xr.Dataset, query: dict) -> Response: vertical = query.get("vertical", "false") == "true" # colorbaronly = query.get("colorbaronly", "False") == "True" colorscalerange = [ - float(x) for x in query.get("colorscalerange", "nan,nan").split(",") + parse_float(x) for x in query.get("colorscalerange", "nan,nan").split(",") ] if isnan(colorscalerange[0]): autoscale = True diff --git a/xpublish_wms/wms/get_map.py b/xpublish_wms/wms/get_map.py index d14433c..dbee864 100644 --- a/xpublish_wms/wms/get_map.py +++ b/xpublish_wms/wms/get_map.py @@ -16,6 +16,7 @@ from fastapi.responses import StreamingResponse from xpublish_wms.grid import RenderMethod +from xpublish_wms.utils import parse_float logger = logging.getLogger("uvicorn") @@ -150,7 +151,7 @@ def ensure_query_types(self, ds: xr.Dataset, query: dict): self.palettename = self.DEFAULT_PALETTE self.colorscalerange = [ - float(x) for x in query.get("colorscalerange", "nan,nan").split(",") + parse_float(x) for x in query.get("colorscalerange", "nan,nan").split(",") ] self.autoscale = query.get("autoscale", "false") == "true" @@ -199,8 +200,7 @@ def select_elevation(self, ds: xr.Dataset, da: xr.DataArray) -> xr.DataArray: :param da: :return: """ - da = ds.gridded.select_by_elevation(da, self.elevation) - print(da.shape) + da = ds.gridded.select_by_elevation(da, [self.elevation]) return da