Skip to content

Commit

Permalink
fix fvcom grid, update grid.select_by_elevation, and fix bad float pa…
Browse files Browse the repository at this point in the history
…rsing
  • Loading branch information
ndellicarpini committed Nov 7, 2023
1 parent e843618 commit 5ec5c46
Show file tree
Hide file tree
Showing 6 changed files with 111 additions and 39 deletions.
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ cf_xarray
datashader
matplotlib
Pillow
scikit-learn
rioxarray>=0.12.2
xarray
xpublish>=0.3.0,<0.4.0
120 changes: 93 additions & 27 deletions xpublish_wms/grid.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Optional, Tuple, Union
from typing import Any, Optional, Tuple, Union, Sequence

import cartopy.geodesic
import cf_xarray # noqa
Expand All @@ -11,6 +11,7 @@
import xarray as xr

from xpublish_wms.utils import strip_float, to_mercator
from sklearn.neighbors import BallTree


class RenderMethod(Enum):
Expand Down Expand Up @@ -84,11 +85,21 @@ def elevations(self, da: xr.DataArray) -> Optional[xr.DataArray]:
def select_by_elevation(
self,
da: xr.DataArray,
elevation: float = 0.0,
elevations: Sequence[float],
) -> xr.DataArray:
"""Select the given data array by elevation"""

if elevations is None or len(elevations) == 0 or all(v is None for v in elevations):
elevations = [0.0]

if "vertical" in da.cf:
da = da.cf.sel({"vertical": elevation}, method="nearest")
if len(elevations) == 1:
return da.cf.sel(vertical=elevations[0], method="nearest")
elif len(elevations) > 1:
return da.cf.sel(vertical=elevations)
else:
return da.cf.sel(vertical=0, method="nearest")

return da

def mask(
Expand Down Expand Up @@ -365,46 +376,101 @@ def elevation_positive_direction(self, da: xr.DataArray) -> Optional[str]:
def elevations(self, da: xr.DataArray) -> Optional[xr.DataArray]:
if "vertical" in da.cf:
return da.cf["vertical"][:, 0]
elif "siglay" in da.dims:
# Sometimes fvcom variables dont have coordinates assigned correctly, so brute force it
return self.ds.siglay[:, 0]
elif "siglev" in da.dims:
else:
# Sometimes fvcom variables dont have coordinates assigned correctly, so brute force it
return self.ds.siglev[:, 0]
vertical_var = None
if "siglay" in da.dims:
vertical_var = "siglay"
elif "siglev" in da.dims:
vertical_var = "siglev"

if vertical_var is not None:
temp_elevations = self.ds[vertical_var].values[:, 0]
return xr.DataArray(
data=[temp_elevations[i] for i in da[vertical_var]],
dims=da[vertical_var].dims,
coords=da[vertical_var].coords,
name=self.ds[vertical_var].name,
attrs=self.ds[vertical_var].attrs
)

return None

def sel_lat_lng(
self,
subset: xr.Dataset,
lng,
lat,
parameters,
) -> Tuple[xr.Dataset, list, list]:
"""Select the given dataset by the given lon/lat and optional elevation"""

lng_rad = np.deg2rad(subset.cf["longitude"])
lat_rad = np.deg2rad(subset.cf["latitude"])

stacked = np.stack([lng_rad, lat_rad], axis=-1)
tree = BallTree(stacked, leaf_size=2, metric="haversine")

idx = tree.query([[np.deg2rad((360 + lng) if lng < 0 else lng), np.deg2rad(lat)]], return_distance=False)

if 'nele' in subset.dims:
subset = subset.isel(nele=idx[0][0])
else:
return None
subset = subset.isel(node=idx[0][0])

x_axis = [strip_float(subset.cf["longitude"])]
y_axis = [strip_float(subset.cf["latitude"])]
return subset, x_axis, y_axis

def select_by_elevation(
self,
da: xr.DataArray,
elevation: Optional[float],
elevations: Optional[Sequence[float]],
) -> xr.DataArray:
"""Select the given data array by elevation"""
print(da.coords)
print(da.cf)
if not self.has_elevation(da):
return da

if elevation is None:
elevation = 0.0
if elevations is None or len(elevations) == 0 or all(v is None for v in elevations):
elevations = [0.0]

da_elevations = self.elevations(da)

elevation_index = [int(np.absolute(da_elevations - elevation).argmin().values) for elevation in elevations]
if len(elevation_index) == 1:
elevation_index = elevation_index[0]

if "vertical" not in da.cf:
if "siglay" in da.dims:
da.__setitem__("siglay", da_elevations)
elif "siglev" in da.dims:
da.__setitem__("siglev", da_elevations)

elevations = self.elevations(da)
diff = np.absolute(elevations - elevation)
elevation_index = int(diff.argmin().values)
if "vertical" in da.cf:
da = da.cf.isel(vertical=elevation_index)
elif "siglay" in da.dims:
print(elevation_index)
da = da.isel(siglay=elevation_index)
elif "siglev" in da.dims:
da = da.isel(siglev=elevation_index)

print(da.coords)
print(da.cf)

return da

def project(self, da: xr.DataArray, crs: str) -> Any:
# fvcom nodal variables have data on both the faces and edges
if 'nele' in da.dims:
elem_count = self.ds.ntve.isel(time=0).values
neighbors = self.ds.nbve.isel(time=0).values
mask = (neighbors[:, :] > 0)

new_values = np.sum(da.values[neighbors[:, :] - 1], axis=0, where=mask) / elem_count
da = xr.DataArray(
data=new_values,
dims=da.dims,
name=da.name,
attrs=da.attrs,
coords=dict(
lonc=(da.cf["longitude"].dims, self.ds.lon.values, da.cf["longitude"].attrs),
latc=(da.cf["latitude"].dims, self.ds.lat.values, da.cf["latitude"].attrs),
time=da.coords['time']
),
)

if crs == "EPSG:4326":
da = da.assign_coords({"x": da.cf["longitude"], "y": da.cf["latitude"]})
elif crs == "EPSG:3857":
Expand Down Expand Up @@ -528,12 +594,12 @@ def elevations(self, da: xr.DataArray) -> Optional[xr.DataArray]:
def select_by_elevation(
self,
da: xr.DataArray,
elevation: Optional[float],
elevations: Optional[Sequence[float]],
) -> xr.DataArray:
if self._grid is None:
return None
else:
return self._grid.select_by_elevation(da, elevation)
return self._grid.select_by_elevation(da, elevations)

def mask(
self,
Expand Down
8 changes: 8 additions & 0 deletions xpublish_wms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,14 @@ def strip_float(value):
return float(value.values)


def parse_float(value):
if 'e' in value.lower():
part_arr = value.lower().split("e")
return float(part_arr[0].strip()) * (10 ** float(part_arr[1].strip()))

return float(value.strip())


def round_float_values(v) -> list:
if not isinstance(v, list):
return round(v, 5)
Expand Down
11 changes: 3 additions & 8 deletions xpublish_wms/wms/get_feature_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,14 +159,9 @@ def get_feature_info(ds: xr.Dataset, query: dict) -> Response:
if any_has_vertical_axis:
if elevation == "all":
# Dont select an elevation, just keep all elevation coords
elevation = selected_ds.cf["vertical"].values
elif len(elevation) == 1:
selected_ds = selected_ds.cf.sel(vertical=elevation, method="nearest")
elif len(elevation) > 1:
selected_ds = selected_ds.cf.sel(vertical=slice(elevation[0], elevation[1]))
else:
# Select closest to the surface by default
selected_ds = selected_ds.cf.sel(vertical=0, method="nearest")
elevation = ds.gridded.elevations(selected_ds)

selected_ds = ds.gridded.select_by_elevation(selected_ds, elevation)

try:
# Apply masking if necessary
Expand Down
4 changes: 3 additions & 1 deletion xpublish_wms/wms/get_legend_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from matplotlib import cm
from PIL import Image

from xpublish_wms.utils import parse_float


def get_legend_info(dataset: xr.Dataset, query: dict) -> Response:
"""
Expand All @@ -18,7 +20,7 @@ def get_legend_info(dataset: xr.Dataset, query: dict) -> Response:
vertical = query.get("vertical", "false") == "true"
# colorbaronly = query.get("colorbaronly", "False") == "True"
colorscalerange = [
float(x) for x in query.get("colorscalerange", "nan,nan").split(",")
parse_float(x) for x in query.get("colorscalerange", "nan,nan").split(",")
]
if isnan(colorscalerange[0]):
autoscale = True
Expand Down
6 changes: 3 additions & 3 deletions xpublish_wms/wms/get_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from fastapi.responses import StreamingResponse

from xpublish_wms.grid import RenderMethod
from xpublish_wms.utils import parse_float

logger = logging.getLogger("uvicorn")

Expand Down Expand Up @@ -150,7 +151,7 @@ def ensure_query_types(self, ds: xr.Dataset, query: dict):
self.palettename = self.DEFAULT_PALETTE

self.colorscalerange = [
float(x) for x in query.get("colorscalerange", "nan,nan").split(",")
parse_float(x) for x in query.get("colorscalerange", "nan,nan").split(",")
]
self.autoscale = query.get("autoscale", "false") == "true"

Expand Down Expand Up @@ -199,8 +200,7 @@ def select_elevation(self, ds: xr.Dataset, da: xr.DataArray) -> xr.DataArray:
:param da:
:return:
"""
da = ds.gridded.select_by_elevation(da, self.elevation)
print(da.shape)
da = ds.gridded.select_by_elevation(da, [self.elevation])

return da

Expand Down

0 comments on commit 5ec5c46

Please sign in to comment.