diff --git a/docs/whats_new.rst b/docs/whats_new.rst
index 45d8d85..4c5e5cb 100644
--- a/docs/whats_new.rst
+++ b/docs/whats_new.rst
@@ -1,6 +1,11 @@
 :mod:`What's New`
 -----------------
 
+v1.2.0 (September 13, 2023)
+===========================
+* Improvements to interpolation
+
+
 v1.1.4 (January 27, 2023)
 =========================
 * fixed docs to run fully
diff --git a/extract_model/__init__.py b/extract_model/__init__.py
index 0d5182e..04a51d5 100644
--- a/extract_model/__init__.py
+++ b/extract_model/__init__.py
@@ -9,7 +9,8 @@
 import extract_model.accessor  # noqa: F401
 from .extract_model import sel2d, sel2dcf, select, selZ  # noqa: F401
-from .utils import filter, order, preprocess, sub_bbox, sub_grid  # noqa: F401
+from .preprocessing import preprocess
+from .utils import filter, guess_model_type, order, sub_bbox, sub_grid  # noqa: F401
 
 try:
diff --git a/extract_model/accessor.py b/extract_model/accessor.py
index 9ce24ed..7a77e84 100644
--- a/extract_model/accessor.py
+++ b/extract_model/accessor.py
@@ -262,28 +262,34 @@ def interp2d(
                 latitude=lats,
                 locstream=locstream,
                 weights=weights,
+                horizontal_interp=True,
+                horizontal_interp_code="xesmf",
                 iT=iT,
                 T=T,
                 iZ=iZ,
                 Z=Z,
                 extrap=extrap,
                 extrap_val=extrap_val,
+                return_info=True,
             )
         else:
-            da, weights = em.select(
+            da, kwargs_out = em.select(
                 self.da.to_dataset(),
                 longitude=lons,
                 latitude=lats,
                 locstream=locstream,
                 weights=weights,
+                horizontal_interp=True,
+                horizontal_interp_code="xesmf",
                 iT=iT,
                 T=T,
                 iZ=iZ,
                 Z=Z,
                 extrap=extrap,
                 extrap_val=extrap_val,
+                return_info=True,
             )
-            self.weights_map[hashenc] = weights
+            self.weights_map[hashenc] = kwargs_out["weights"]
 
         return da[varname]
diff --git a/extract_model/extract_model.py b/extract_model/extract_model.py
index 75ad0af..04716c2 100644
--- a/extract_model/extract_model.py
+++ b/extract_model/extract_model.py
@@ -15,6 +15,8 @@
 from dask.delayed import Delayed
 from xarray import DataArray, Dataset
 
+from .utils import calc_barycentric, interp_with_barycentric, order, tree_query
+
 try:
     import xesmf as xe
@@ -35,10 +37,10 @@ def interp_multi_dim(
     da: DataArray,
     da_out: Optional[DataArray] = None,
-    T: Optional[Union[str, list]] = None,
-    Z: Optional[Union[str, list]] = None,
-    iT: Optional[Union[int, list]] = None,
-    iZ: Optional[Union[int, list]] = None,
     extrap_method: Optional[str] = None,
     locstream: bool = False,
     weights=None,
@@ -67,6 +69,7 @@
         raise ModuleNotFoundError(  # pragma: no cover
             "xESMF is not available so horizontal interpolation in 2D cannot be performed."
         )
+    da = da.chunk({da.cf["Y"].name: -1, da.cf["X"].name: -1})
 
     # set up regridder, use weights if available
     regridder = xe.Regridder(
@@ -184,8 +187,18 @@
     iZ=None,
     extrap=False,
     extrap_val=None,
+    horizontal_interp: bool = False,
+    horizontal_interp_code: str = "xesmf",
+    use_projection: bool = True,
+    triangulation=None,
+    mask: Optional[DataArray] = None,
+    use_xoak: bool = False,
+    vertical_interp: bool = False,
+    xgcm_grid=None,
     locstream=False,
     weights=None,
+    make_time_series=False,
+    return_info: bool = False,
 ):
     """Extract output from da at location(s).
 
@@ -196,7 +209,9 @@
     longitude, latitude: int, float, list, array (1D or 2D), DataArray, optional
         longitude(s), latitude(s) at which to return model output.
         Package `xESMF` will be used to interpolate with "bilinear" to
-        these horizontal locations.
+        these horizontal locations if horizontal_interp is True. If
+        horizontal_interp is False, then the nearest neighbors will be found
+        with sel2d.
     T: datetime-like string, list of datetime-like strings, optional
         Datetime or datetimes at which to return model output.
         `xarray`'s built-in 1D interpolation will be used to calculate.
@@ -223,6 +238,23 @@
     extrap_val: int, float, optional
        If `extrap==False`, values outside domain will be returned as 0,
        or as `extrap_value` if input.
+    horizontal_interp: bool
+        True to interpolate, False to find the nearest neighbor with sel2d. Defaults to False.
+    horizontal_interp_code: str
+        Default "xesmf" to use package ``xESMF`` for horizontal interpolation, which is probably better if you need to interpolate to many points. To use ``xESMF`` you have to install it as an optional dependency. Input "tree" to use BallTree to find the nearest grid points and interpolate within the triangle they form using barycentric coordinates. This has been tested for interpolating to 3 locations so far. Input "delaunay" to use a Delaunay triangulation to find the nearest triangle points and interpolate the same as with "tree" using barycentric coordinates. This should be faster when you have more points to interpolate to, especially if you save and reuse the triangulation.
+    use_projection: bool
+        If True, project coordinates using an Albers Equal Area projection before running the interpolation. Only applies for ``horizontal_interp_code="tree"`` and ``horizontal_interp_code="delaunay"``. Otherwise stay in lon/lat coordinates.
+    triangulation:
+        Input to reuse a Delaunay triangulation made up of the lon/lat grid for interpolation. Use with ``horizontal_interp_code="delaunay"``.
+    mask : DataArray, optional
+        Mask associated with da to be used with the nearest neighbor search, if desired, so that
+        masked elements are not returned.
+    use_xoak : bool
+        If True, use xoak to find the single nearest point. If False, use BallTree directly to find the distances and nearest 4 points.
+    vertical_interp : bool
+        If True, interpolate in depth to input Z values. If False, use `method="nearest"` to return the nearest depths to Z.
+    xgcm_grid
+        If you want to interpolate vertically in more than 1 dimension, you need to have an xgcm grid defined that tells xgcm how to interpret the grid vertically. You can get this for known models using, e.g., ``preprocess_roms_grid()``.
     locstream: boolean, optional
         Which type of interpolation to do:
@@ -231,10 +263,14 @@
     weights: xESMF netCDF file path, DataArray, optional
         If a weights file or array exists you can pass it as an argument here and it will be re-used.
+    make_time_series: bool
+        If True, use advanced indexing to select the final result down to a time series matching the da.cf["T"] length.
+    return_info: bool
+        If True, also return a dict of extra information whose contents depend on what processes were run.
 
     Returns
     -------
-    DataArray of interpolated and/or selected values from da, and array of weights.
+    DataArray of interpolated and/or selected values from da and, if ``return_info=True``, a dict kwargs_out whose contents depend on which options were run.
 
     Examples
     --------
@@ -250,6 +286,8 @@
     >>> da_out = em.select(**kwargs)
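+
+    To horizontally interpolate with the BallTree/barycentric route instead of
+    ``xESMF``, and also get back the extra information dict, a sketch (``da``,
+    ``longitude``, and ``latitude`` as in the example above):
+
+    >>> da_out, kwargs_out = em.select(
+    ...     da,
+    ...     longitude=longitude,
+    ...     latitude=latitude,
+    ...     horizontal_interp=True,
+    ...     horizontal_interp_code="tree",
+    ...     return_info=True,
+    ... )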
     """
+    kwargs_out = {}
+
     # Must select or interpolate for depth and time.
     # - i.e. One cannot run in both Z and iZ mode, same for T/iT
     if (Z is not None) and (iZ is not None):
@@ -265,10 +303,10 @@
             latitude = [latitude]
         longitude = np.asarray(longitude)
         latitude = np.asarray(latitude)
-        # If longitude and latitude are input, perform horizontal interpolation
-        horizontal_interp = True
-    else:
-        horizontal_interp = False
 
     # If extrapolating, define method
     if extrap:
@@ -298,26 +336,193 @@
     # Perform interpolation
     if horizontal_interp:
-        # create Dataset to interpolate to
-        ds_out = make_output_ds(longitude, latitude, locstream=locstream)
 
-        if XESMF_AVAILABLE:
+        if XESMF_AVAILABLE and horizontal_interp_code == "xesmf":
+            # create Dataset to interpolate to
+            ds_out = make_output_ds(longitude, latitude, locstream=locstream)
+
             da, weights = interp_multi_dim(
                 da,
                 ds_out,
-                T=T,
-                Z=Z,
-                iT=iT,
-                iZ=iZ,
                 extrap_method=extrap_method,
                 locstream=locstream,
                 weights=weights,
             )
+            kwargs_out["weights"] = weights
+
+        elif horizontal_interp_code == "delaunay":
+
+            # this should be faster than "tree" when you have many points to interpolate to
+            # get points for triangles from the Delaunay triangulation
+            from scipy.spatial import Delaunay
+
+            # calculate triangulation
+            X = np.stack(
+                [np.ravel(c) for c in [da.cf["longitude"], da.cf["latitude"]]]
+            ).T
+            if triangulation is not None:
+                tri = triangulation
+            else:
+                tri = Delaunay(X)
+
+            # save triangulation for potential future use
+            kwargs_out["tri"] = tri
+
+            # prep points to interpolate to
+            p = np.stack([np.ravel(c) for c in [longitude, latitude]]).T
+
+            # itri are the triangle indices containing the points we are interpolating to
+            itri = tri.find_simplex(p)
+
+            # xs, ys are the vertices of the triangles around the points, npts x 3
+            # xs, ys, x, y here are actually lons and lats but will be overwritten below if
+            # use_projection
+            xs, ys = X[tri.simplices[itri]][:, :, 0], X[tri.simplices[itri]][:, :, 1]
+            x, y = longitude, latitude
+            iys, ixs = np.unravel_index(tri.simplices[itri], da.cf["longitude"].shape)
+
+            if use_projection:
+                # convert to projected coordinates
+                import pyproj
+
+                min_lat, max_lat = float(da.cf["latitude"].min()), float(
+                    da.cf["latitude"].max()
+                )
+                mean_lon = float(da.cf["longitude"].mean())
+                proj = pyproj.Proj(
+                    f"+proj=aea +lat_1={max_lat} +lat_2={min_lat} +lon_0={mean_lon}"
+                )
+
+                xs, ys = proj(xs, ys)
+                x, y = proj(longitude, latitude)
+
+            lam = calc_barycentric(x, y, xs, ys)
+            # interp_coords are the coords and indices that went into the interpolation
+            da, interp_coords = interp_with_barycentric(da, ixs, iys, lam)
+            kwargs_out["interp_coords"] = interp_coords
+
+        elif horizontal_interp_code == "tree":
+
+            # get points from balltree
+            npts = len(longitude)
+
+            # distances, iys_init, ixs_init are npts x k=4
+            distances, (iys_init, ixs_init) = tree_query(
+                da.cf["longitude"], da.cf["latitude"], longitude, latitude, k=4
+            )
+
+            if use_projection:
+                import pyproj
+
+                # convert to projected coordinates
+                min_lat, max_lat = float(da.cf["latitude"].min()), float(
+                    da.cf["latitude"].max()
+                )
+                mean_lon = float(da.cf["longitude"].mean())
+                proj = pyproj.Proj(
+                    f"+proj=aea +lat_1={max_lat} +lat_2={min_lat} +lon_0={mean_lon}"
+                )
+
+                # points to find
+                x, y = proj(longitude, latitude)
+            else:
+                x, y = longitude, latitude
+
+            # Choose the first three points that actually make a triangle around each point
+            from shapely.geometry import Point, Polygon
+
+            # xs, ys store the 3 nearest points making up triangles around each of npts
+            # npts x 3
+            xs, ys = np.empty((npts, 3)), np.empty((npts, 3))
+            ixs, iys = np.empty((npts, 3), dtype=int), np.empty((npts, 3), dtype=int)
+
+            # check combinations of 3 points for each npt to find triangle
+            import itertools
+
+            for ipt in range(npts):
+                # unique combinations of 3 of the nearest points to ipt, by index
+                ix_combinations = list(itertools.combinations(ixs_init[ipt], 3))
+                iy_combinations = list(itertools.combinations(iys_init[ipt], 3))
+
+                # point we are interpolating to
+                # if use_projection, in projected coordinates, otherwise in lon/lat
+                pt = Point(x[ipt], y[ipt])
+
+                def pt_in_itriangle(ix, iy):
+                    """lon/lat triangle points for combinations"""
+                    xs, ys = (
+                        da.cf["longitude"].values[iy, ix],
+                        da.cf["latitude"].values[iy, ix],
+                    )
+                    return Polygon(np.stack([xs, ys]).T).contains(pt)
+
+                def pt_in_itriangle_proj(ix, iy):
+                    """projected triangle points for combinations"""
+                    xs, ys = proj(
+                        da.cf["longitude"].values[iy, ix],
+                        da.cf["latitude"].values[iy, ix],
+                    )
+                    return Polygon(np.stack([xs, ys]).T).contains(pt)
+
+                i = 0
+                if use_projection:
+                    while not pt_in_itriangle_proj(
+                        ix_combinations[i], iy_combinations[i]
+                    ):
+                        i += 1
+                else:
+                    while not pt_in_itriangle(ix_combinations[i], iy_combinations[i]):
+                        i += 1
+
+                # also update ixs, iys
+                ixs[ipt], iys[ipt] = ix_combinations[i], iy_combinations[i]
+                if use_projection:
+                    xs[ipt], ys[ipt] = proj(
+                        da.cf["longitude"].values[iys[ipt], ixs[ipt]],
+                        da.cf["latitude"].values[iys[ipt], ixs[ipt]],
+                    )
+                else:
+                    xs[ipt], ys[ipt] = (
+                        da.cf["longitude"].values[iys[ipt], ixs[ipt]],
+                        da.cf["latitude"].values[iys[ipt], ixs[ipt]],
+                    )
+
+            lam = calc_barycentric(x, y, xs, ys)
+            # interp_coords are the coords and indices that went into the interpolation
+            da, interp_coords = interp_with_barycentric(da, ixs, iys, lam)
+            kwargs_out["interp_coords"] = interp_coords
+
         else:
             raise ModuleNotFoundError(
                 "xESMF is not available so horizontal interpolation in 2D cannot be performed."
             )
+    # nearest neighbor instead
+    elif not horizontal_interp and longitude is not None and latitude is not None:
+        da, k_out = da.em.sel2dcf(
+            longitude=longitude,
+            latitude=latitude,
+            mask=mask,
+            use_xoak=use_xoak,
+            return_info=True,
+        )
+        kwargs_out.update(k_out)
+    else:
+        print("no horizontal grid changes")
+
     if iT is not None:
         with xr.set_options(keep_attrs=True):
             da = da.cf.isel(T=iT)
@@ -331,50 +536,120 @@
         with xr.set_options(keep_attrs=True):
             da = da.cf.isel(Z=iZ)
 
-    # deal with interpolation in Z separately
     elif Z is not None:
-        # can do interpolation in depth for any number of dimensions if the
-        # vertical coord is 1d
-        if da.cf["vertical"].ndim == 1:
-            da = da.cf.interp(vertical=Z)
+        # deal with interpolation in Z separately
+        if vertical_interp:
+            # can do interpolation in depth for any number of dimensions if the
+            # vertical coord is 1d
+            if da.cf["vertical"].ndim == 1:
+                if da.cf["vertical"].name in ("s_rho", "s_w"):
+                    raise UserWarning(
+                        f"The dimension identified as the vertical coordinate is {da.cf['vertical'].name} which might not be correct."
+                    )
 
-        # if the vertical coord is greater than 1D, can only do restricted interpolation
-        # at the moment
-        else:
-            da = da.squeeze()
-            if len(da.dims) == 1 and da.cf["Z"].name in da.dims:
-                da = da.swap_dims({da.cf["Z"].name: da.cf["vertical"].name})
                 da = da.cf.interp(vertical=Z)
-            elif len(da.dims) == 2 and da.cf["Z"].name in da.dims:
-                # loop over other dimension
-                dim_var_name = list(set(da.dims) - set([da.cf["Z"].name]))[0]
-                new_da = []
-                for i in range(len(da[dim_var_name])):
-                    new_da.append(
-                        da.isel({dim_var_name: i})
-                        .swap_dims({da.cf["Z"].name: da.cf["vertical"].name})
-                        .cf.interp(vertical=Z)
+
+            else:
+                # need "grid" from xgcm set up in preprocessing, e.g.:
+                # from xgcm import Grid
+                # grid = Grid(ds, coords={"Z": {"center": "s_rho", "outer": "s_w"}}, periodic=False)
+                if xgcm_grid is None:
+                    raise ValueError(
+                        "Need xgcm 'grid' object set up and input as ``xgcm_grid``."
                     )
-                da = xr.concat(new_da, dim=dim_var_name)
-            elif len(da.dims) > 2:
-                # need to implement (x)isoslice here
-                raise NotImplementedError(
-                    "Currently it is not possible to interpolate in depth with more than 1 other (time) dimension."
+                z_attrs = da.cf["vertical"].attrs
+                z_attrs.update({"axis": "Z"})
+                zkey = da.cf["vertical"].name
+                da = xgcm_grid.transform(
+                    da, "Z", np.array(Z), target_data=da.cf["vertical"], method="linear"
                 )
+                da[zkey].attrs = z_attrs
+                da = order(da)  # reorder dimensions to convention
+
+        # or just grab nearest
+        else:
+            # this works if Z is depth instead of sigma
+            da = da.cf.sel(Z=Z, method="nearest")
+
+    # advanced indexing to select all assuming coherent time series
+    # make sure len of each dimension matches
+    if make_time_series:
+        dims_to_index = [da.cf["T"].name]
+        ntimes = len(da.cf["T"])
+        for axis, var_names in da.cf.axes.items():
+            for var_name in var_names:
+                if len(da[var_name]) == ntimes and var_name in da.dims:
+                    dims_to_index.append(var_name)
+
+        # use time dim as dim for all since treating as time series
+        indexer = {
+            dim: xr.DataArray(np.arange(0, ntimes), dims=da.cf["T"].name)
+            for dim in dims_to_index
+        }
+        da = da.isel(indexer)
 
     if extrap_val is not None:
         # returns 0 outside the domain by default. Assumes that no other values are exactly 0
         # and replaces all 0's with extrap_val if chosen.
         da = da.where(da != 0, extrap_val)
 
-    return da.squeeze(), weights
+    if return_info:
+        return da.squeeze(), kwargs_out
+    else:
+        return da.squeeze()
 
 
 def sel2d(
     var,
     mask: Optional[DataArray] = None,
-    distances_name: Optional[str] = None,
+    use_xoak: bool = True,
+    return_info: bool = False,
     **kwargs,
 ):
     """Find the value of the var at closest location to inputs, optionally respecting mask.
@@ -402,12 +677,16 @@
     A Dataset will "remember" the index calculated for whichever grid coordinates were first requested and subsequently run faster for requests on that grid (and not run for other grids).
     mask : DataArray, optional
         If input, mask is applied to lon/lat so that if requested lon/lat is on land, the nearest valid model point will be returned. Otherwise nan's will be returned. If requested lon/lat is outside domain but not on land, the nearest model output will be returned regardless.
-    distances_name : str, optional
-        Provide a name in which to save the distances from xoak; there will be one value per lon/lat location found. If None, distances won't be returned in object.
+    use_xoak : bool
+        If True, use xoak to find the nearest point. If False, use BallTree directly to find the distances and nearest 4 points.
+    return_info: bool
+        If True, return a dict of extra information that depends on what processes were run.
 
     Returns
     -------
-    An xarray object of the same type as input as var which is selected in horizontal coordinates to input locations and, in input, to time and vertical selections. If not selected, other dimensions are brought along. If distances_name is not None, Dataset is returned.
+    An xarray object of the same type as the input var, selected in horizontal coordinates to the input locations and, if input, to the time and vertical selections. If not selected, other dimensions are brought along.
+    Other items returned in kwargs include:
+
+    * distances: the distances from the requested points to the returned nearest points
 
     Notes
     -----
@@ -429,6 +708,12 @@
     >>> em.sel2d(da, lon_rho=-96, lat_rho=27, s_rho=0.0, method='nearest')
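+
+    To use the BallTree search directly instead of xoak and also get back the
+    distances, a sketch with the same example point:
+
+    >>> da_out, kwargs_out = em.sel2d(
+    ...     da, lon_rho=-96, lat_rho=27, use_xoak=False, return_info=True
+    ... )
+    >>> kwargs_out["distances"]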
     """
+    # if dask-backed, read lon/lat into memory for faster working later
+    if var.cf["longitude"].chunks is not None:
+        var[var.cf["longitude"].name] = var.cf["longitude"].load()
+    if var.cf["latitude"].chunks is not None:
+        var[var.cf["latitude"].name] = var.cf["latitude"].load()
+
     # assign input lon/lat coord names
     kwargs_iter = iter(kwargs)
     # first key is assumed longitude unless has "lat" in the name
@@ -453,93 +738,193 @@
     elif isinstance(lons, list) and isinstance(lats, list):
         lons, lats = np.array(lons), np.array(lats)
 
-    # 1D or 2D
-    if lons.ndim == lats.ndim == 1:
-        dims = ("loc",)
-    elif lons.ndim == lats.ndim == 2:
-        dims = ("loc_y", "loc_x")
-    # else: Raise exception
-
-    # create Dataset
-    ds_to_find = xr.Dataset(
-        {
-            "lat_to_find": (dims, lats, {"standard_name": "latitude"}),
-            "lon_to_find": (dims, lons, {"standard_name": "longitude"}),
-        }
-    )
+    if use_xoak:
 
-    if mask is not None:
+        # 1D or 2D
+        if lons.ndim == lats.ndim == 1:
+            dims = ("loc",)
+        elif lons.ndim == lats.ndim == 2:
+            dims = ("loc_y", "loc_x")
+        # else: Raise exception
+
+        # create Dataset
+        ds_to_find = xr.Dataset(
+            {
+                "lat_to_find": (dims, lats, {"standard_name": "latitude"}),
+                "lon_to_find": (dims, lons, {"standard_name": "longitude"}),
+            }
+        )
 
-        # Assume mask is 2D — but not true for wetting/drying
+        if mask is not None:
 
-        # find indices representing mask
-        eta, xi = np.where(mask.values)
+            # if dask-backed, read into memory
+            if mask.chunks is not None:
+                mask = mask.load()
 
-        # make advanced indexer to flatten arrays
-        var_flat = var.cf.isel(
-            X=xr.DataArray(xi, dims="loc"), Y=xr.DataArray(eta, dims="loc")
-        )
+            # Assume mask is 2D — but not true for wetting/drying
+            # find indices representing mask
+            eta, xi = np.where(mask.values)
 
-        var = var_flat.copy()
+            # make advanced indexer to flatten arrays
+            var_flat = var.cf.isel(
+                X=xr.DataArray(xi, dims="loc"), Y=xr.DataArray(eta, dims="loc")
+            )
 
-    if var.xoak.index is None:
-        var.xoak.set_index([latname, lonname], "sklearn_geo_balltree")
-    elif (latname, lonname) != var.xoak._index_coords:
-        raise ValueError(
-            f"Index has been built for grid with coords {var.xoak._index_coords} but coord names input are ({latname}, {lonname})."
-        )
-    elif var.xoak.index is not None:
-        pass
-    else:
-        warnings.warn(
-            "Maybe a mask is not present or being properly identified in var. You could use `use_mask=False`.",
-            RuntimeWarning,
+            var = var_flat.copy()
+
+        if var.xoak.index is None:
+            var.xoak.set_index([latname, lonname], "sklearn_geo_balltree")
+        elif (latname, lonname) != var.xoak._index_coords:
+            raise ValueError(
+                f"Index has been built for grid with coords {var.xoak._index_coords} but coord names input are ({latname}, {lonname})."
+            )
+        elif var.xoak.index is not None:
+            pass
+        else:
+            warnings.warn(
+                "Maybe a mask is not present or being properly identified in var. You could use `use_mask=False`.",
+                RuntimeWarning,
+            )
+
+        # perform selection
+        output = var.xoak.sel(
+            {latname: ds_to_find.lat_to_find, lonname: ds_to_find.lon_to_find}
+        )
 
-    # # this version is for the updates in xoak
-    # output = var.xoak.sel(
-    #     {latname: ds_to_find.lat_to_find, lonname: ds_to_find.lon_to_find},
-    #     distances_name=distances_name,
-    # )
-    # output[distances_name] *= 6371  # convert from radians to km
-
-    if distances_name is not None:
+        # # this version is for the updates in xoak
+        # output = var.xoak.sel(
+        #     {latname: ds_to_find.lat_to_find, lonname: ds_to_find.lon_to_find},
+        #     distances_name=distances_name,
+        # )
+        # output[distances_name] *= 6371  # convert from radians to km
+        # only calculate distances this way (outside of xoak itself) if not doing 2D since we just need distances
         # right now for OMSA and don't want a separate soln from xoak for this problem
         if ds_to_find.lat_to_find.ndim > 1:
             with xr.set_options(keep_attrs=True):
-                return output.sel(**kwargs)
+                if return_info:
+                    # no info to return but set up for future use
+                    kwargs_out = {}
+                    return output.sel(**kwargs), kwargs_out
+                else:
+                    return output.sel(**kwargs)
 
         # distances between input points and nearest points - this won't be needed with new version of xoak once merged
         # * 6371 to convert from radians to km
         index = var.xoak._index
         if isinstance(index, tuple):
             index = index[0]
-    distances = index.query(np.array([*zip(lats, lons)]))["distances"][:, 0] * 6371
+        query = index.query(np.array([*zip(lats, lons)]))
+        if isinstance(query, Delayed):
+            query = query.compute()
+        distances = query["distances"][:, 0] * 6371
+        iflat = query["indices"][:, 0].tolist()[0]
+        if mask is None and var.cf["X"].ndim == 1:
+            xi, eta = np.meshgrid(var.cf["X"], var.cf["Y"])
+            xi, eta = xi.flatten(), eta.flatten()
+
+            # this is probably wrong if you aren't looking for matches at a single point
+            ixi = xi[iflat]
+            ieta = eta[iflat]
+
         if isinstance(distances, Delayed):
-            # import pdb; pdb.set_trace()
             distances = distances.compute()
         if not isinstance(output, Dataset):
             output = output.to_dataset()
         attrs = {"units": "km"}
         indexer_dim = ds_to_find.lat_to_find.dims
         indexer_shape = ds_to_find.lat_to_find.shape
-    output[distances_name] = xr.Variable(
-        indexer_dim, distances.reshape(indexer_shape), attrs
-    )
+        distances = xr.Variable(indexer_dim, distances.reshape(indexer_shape), attrs)
 
-    with xr.set_options(keep_attrs=True):
-        return output.sel(**kwargs)
+        with xr.set_options(keep_attrs=True):
+            output = output.sel(**kwargs)
+        kwargs["distances"] = distances
+        if return_info:
+            return output, kwargs
+        else:
+            return output
+
+    else:
+
+        # currently lons, lats 1D only
+
+        # if no mask, assume user just wants 1 nearest point to each input lons/lats pair
+        # probably should expand this later to be more generic
+        if mask is None:
+            k = 1
+        # if user inputs mask, use it to only return the nearest point that is active
+        # so, find nearest 30 points to have options
+        else:
+            k = 30
+
+        distances, (iys, ixs) = tree_query(var[lonname], var[latname], lons, lats, k=k)
+
+        if mask is not None and mask.values[iys, ixs].sum() == 0:
+            raise ValueError("all found values are masked!")
+
+        if mask is not None:
+            # sort mask such that active elements are preferentially to the left in a 2D array
+            isorted_mask = np.argsort(-mask.values[iys, ixs], axis=-1)
+            # sort the ixs and iys according to this sorting so that if there are unmasked indices,
+            # they are leftmost also, and we will use the leftmost values.
+            ixs_brought_along = np.take_along_axis(ixs, isorted_mask, axis=1)
+            iys_brought_along = np.take_along_axis(iys, isorted_mask, axis=1)
+            distances_brought_along = np.take_along_axis(
+                distances, isorted_mask, axis=1
+            )
+            ixs0 = ixs_brought_along[:, 0]
+            iys0 = iys_brought_along[:, 0]
+            distances0 = distances_brought_along[:, 0]
+
+            kwargs[var.cf["Y"].name], kwargs[var.cf["X"].name] = iys0, ixs0
+
+            dims = ("npts",)
+            var_out = var.cf.isel(
+                X=xr.DataArray(ixs0, dims=dims), Y=xr.DataArray(iys0, dims=dims)
+            )
+            # add "X" axis to npts
+            var_out["npts"] = ("npts", var_out.npts.values, {"axis": "X"})
+
+            kwargs["distances"] = distances0 * 6371
+
+        else:
+            kwargs[var.cf["Y"].name], kwargs[var.cf["X"].name] = iys[0], ixs[0]
+
+            dims = ("npts",)
+            var_out = var.cf.isel(
+                X=xr.DataArray(ixs[0], dims=dims), Y=xr.DataArray(iys[0], dims=dims)
+            )
+            # add "X" axis to npts
+            var_out["npts"] = ("npts", var_out.npts.values, {"axis": "X"})
+
+            kwargs["distances"] = distances * 6371
+
+        with xr.set_options(keep_attrs=True):
+            if return_info:
+                return var_out, kwargs
+            else:
+                return var_out
 
 
 def sel2dcf(
     var,
     mask: Optional[DataArray] = None,
-    distances_name: Optional[str] = None,
+    return_info: bool = False,
     **kwargs,
 ):
     """Find nearest value(s) on 2D horizontal grid using cf-xarray names.
@@ -594,7 +979,7 @@
     new_kwargs.update(kwargs)
 
-    return sel2d(var, mask=mask, distances_name=distances_name, **new_kwargs)
+    return sel2d(var, mask=mask, return_info=return_info, **new_kwargs)
 
 
 def selZ(var, depths):
@@ -631,7 +1016,6 @@
                 .swap_dims({var.cf["Z"].name: var.cf["vertical"].name})
                 .cf.sel(vertical=depths, method="nearest")
             )
-            # import pdb; pdb.set_trace()
             out = xr.concat(new_results, dim=dim_var_name)
         else:
             raise NotImplementedError("Sorry only works for 1D and 2D so far.")
diff --git a/extract_model/preprocessing.py b/extract_model/preprocessing.py
new file mode 100644
index 0000000..4e842ac
--- /dev/null
+++ b/extract_model/preprocessing.py
@@ -0,0 +1,378 @@
+"""Preprocessing-related functions for model output."""
+
+
+from typing import Optional
+
+import numpy as np
+import xarray as xr
+
+from extract_model.model_type import ModelType
+
+from .utils import guess_model_type, order
+
+
+def preprocess_roms(
+    ds,
+    grid=None,
+):
+    """Preprocess ROMS model output for use with cf-xarray.
+
+    Also fixes any other known issues with model output.
+
+    Parameters
+    ----------
+    ds: xarray Dataset
+
+    grid: optional
+        Input xgcm grid; if provided, the Dataset is lazily made aware of the 4D z_rho depth coordinates on the u, v, and psi grids.
+
+    Returns
+    -------
+    Same Dataset but with some metadata added and/or altered.
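+
+    Examples
+    --------
+    A sketch of intended usage, building the xgcm grid first so the u, v, and
+    psi depth coordinates are filled in (``ds`` is assumed to be a ROMS Dataset):
+
+    >>> grid = preprocess_roms_grid(ds)
+    >>> ds = preprocess_roms(ds, grid=grid)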
+ """ + + rename = {} + if "eta_u" in ds.dims: + rename["eta_u"] = "eta_rho" + if "xi_v" in ds.dims: + rename["xi_v"] = "xi_rho" + if "xi_psi" in ds.dims: + rename["xi_psi"] = "xi_u" + if "eta_psi" in ds.dims: + rename["eta_psi"] = "eta_v" + ds = ds.rename(rename) + + # add axes attributes for dimensions + dims = [dim for dim in ds.dims if dim.startswith("s_")] + for dim in dims: + ds[dim].attrs["axis"] = "Z" + + if "ocean_time" in ds.keys(): + ds.ocean_time.attrs["axis"] = "T" + ds.ocean_time.attrs["standard_name"] = "time" + elif "time" in ds.keys(): + ds.time.attrs["axis"] = "T" + ds.time.attrs["standard_name"] = "time" + + dims = [dim for dim in ds.dims if dim.startswith("xi_")] + # need to also make this a coordinate to add attributes + for dim in dims: + ds[dim] = (dim, np.arange(ds.sizes[dim]), {"axis": "X"}) + + dims = [dim for dim in ds.dims if dim.startswith("eta_")] + for dim in dims: + ds[dim] = (dim, np.arange(ds.sizes[dim]), {"axis": "Y"}) + + # Fix standard_name for s_rho/s_w + if "Vtransform" in ds.data_vars and "s_rho" in ds.coords: + cond1 = ( + ds["Vtransform"] == 1 + and ds["s_rho"].attrs["standard_name"] == "ocean_s_coordinate" + ) + cond2 = ( + ds["Vtransform"] == 2 + and ds["s_rho"].attrs["standard_name"] == "ocean_s_coordinate" + ) + if cond1: + ds["s_rho"].attrs["standard_name"] = "ocean_s_coordinate_g1" + elif cond2: + ds["s_rho"].attrs["standard_name"] = "ocean_s_coordinate_g2" + + cond1 = ( + ds["Vtransform"] == 1 + and ds["s_w"].attrs["standard_name"] == "ocean_s_coordinate" + ) + cond2 = ( + ds["Vtransform"] == 2 + and ds["s_w"].attrs["standard_name"] == "ocean_s_coordinate" + ) + if cond1: + ds["s_w"].attrs["standard_name"] = "ocean_s_coordinate_g1" + elif cond2: + ds["s_w"].attrs["standard_name"] = "ocean_s_coordinate_g2" + + # create vertical coordinates z_rho and z_w + name_dict = {} + if "s_rho" in ds.dims: + name_dict["s_rho"] = "z_rho" + if "positive" in ds.s_rho.attrs: + ds.s_rho.attrs.pop("positive") + if "s_w" in ds.dims: + name_dict["s_w"] = "z_w" + if "positive" in ds.s_w.attrs: + ds.s_w.attrs.pop("positive") + ds.cf.decode_vertical_coords(outnames=name_dict) + + # expand Z coordinates to u and v grids + if grid is not None: + # necessary for interpolating u and v to depths + # ds.coords["z_w"] = order(ds["z_w"]) + # ds.coords["z_w_u"] = grid.interp(ds.z_w.chunk({ds.z_w.cf["X"].name: -1}), "X") + # ds.coords["z_w_u"].attrs = { + # "long_name": "depth of U-points on vertical W grid", + # "time": "ocean_time", + # "field": "z_w_u, scalar, series", + # "units": "m", + # } + # ds.coords["z_w_v"] = grid.interp(ds.z_w.chunk({ds.z_w.cf["Y"].name: -1}), "Y") + # ds.coords["z_w_v"].attrs = { + # "long_name": "depth of V-points on vertical W grid", + # "time": "ocean_time", + # "field": "z_w_v, scalar, series", + # "units": "m", + # } + # ds.coords["z_w_psi"] = grid.interp(ds.z_w_u.chunk({ds.z_w_u.cf["Y"].name: -1}), "Y") + # ds.coords["z_w_psi"].attrs = { + # "long_name": "depth of PSI-points on vertical W grid", + # "time": "ocean_time", + # "field": "z_w_psi, scalar, series", + # "units": "m", + # } + + ds.coords["z_rho"] = order(ds["z_rho"]) + ds.coords["z_rho_u"] = grid.interp( + ds.z_rho.chunk({ds.z_rho.cf["X"].name: -1}), "X" + ) + ds.coords["z_rho_u"].attrs = { + "long_name": "depth of U-points on vertical RHO grid", + "time": "ocean_time", + "field": "z_rho_u, scalar, series", + "units": "m", + } + + ds.coords["z_rho_v"] = grid.interp( + ds.z_rho.chunk({ds.z_rho.cf["Y"].name: -1}), "Y" + ) + ds.coords["z_rho_v"].attrs = { + "long_name": "depth of 
+            "time": "ocean_time",
+            "field": "z_rho_v, scalar, series",
+            "units": "m",
+        }
+
+        ds.coords["z_rho_psi"] = grid.interp(
+            ds.z_rho_u.chunk({ds.z_rho_u.cf["Y"].name: -1}), "Y"
+        )
+        ds.coords["z_rho_psi"].attrs = {
+            "long_name": "depth of PSI-points on vertical RHO grid",
+            "time": "ocean_time",
+            "field": "z_rho_psi, scalar, series",
+            "units": "m",
+        }
+
+        # will use this to update coordinate encoding
+        name_dict.update(
+            {"filler1": "z_rho_u", "filler2": "z_rho_v", "filler3": "z_rho_psi"}
+        )
+
+    # fix attrs
+    for zname in [var for var in ds.coords if "z_rho" in var or "z_w" in var]:
+        if zname in ds:
+            # coord inherits from one of the vars going into calculation
+            ds[zname].attrs = {}
+            ds[zname].attrs["positive"] = "up"
+            ds[zname].attrs["units"] = "m"
+            ds[zname] = order(ds[zname])
+
+    # replace s_rho with z_rho, etc, to make z_rho the vertical coord
+    for sname, zname in name_dict.items():
+        for var in ds.data_vars:
+            if ds[var].ndim == 4:
+                if "coordinates" in ds[var].encoding:
+                    coords = ds[var].encoding["coordinates"]
+                    if sname in coords:  # replace if present
+                        coords = coords.replace(sname, zname)
+                    else:  # still add z_rho or z_w
+                        if zname in ds[var].coords and ds[zname].shape == ds[var].shape:
+                            coords += f" {zname}"
+                    ds[var].encoding["coordinates"] = coords
+
+    # Add standard_names for typical ROMS variables, without overwriting
+    # a standard_name that is already present
+    var_map = {
+        "zeta": "sea_surface_elevation",
+        "salt": "sea_water_practical_salinity",
+        "temp": "sea_water_temperature",
+    }
+    for var_name, standard_name in var_map.items():
+        if var_name in ds.data_vars and "standard_name" not in ds[var_name].attrs:
+            ds[var_name].attrs["standard_name"] = standard_name
+
+    # Fix calendar if wrong
+    attrs = ds[ds.cf["T"].name].attrs
+    if ("calendar" in attrs) and (attrs["calendar"] == "gregorian_proleptic"):
+        attrs["calendar"] = "proleptic_gregorian"
+    ds[ds.cf["T"].name].attrs = attrs
+
+    if "s_rho" in ds.dims:
+        if "positive" in ds.s_rho.attrs:
+            ds.s_rho.attrs.pop("positive")
+    if "s_w" in ds.dims:
+        if "positive" in ds.s_w.attrs:
+            ds.s_w.attrs.pop("positive")
+
+    return ds
+
+
+def preprocess_roms_grid(ds):
+    """Create the xgcm grid object for ROMS Dataset ds."""
+    from xgcm import Grid
+
+    coords = {
+        "X": {"center": "xi_rho", "inner": "xi_u"},
+        "Y": {"center": "eta_rho", "inner": "eta_v"},
+        "Z": {"center": "s_rho", "outer": "s_w"},
+    }
+    grid = Grid(ds, coords=coords, periodic=False)
+    return grid
+
+
+def preprocess_fvcom(ds):
+    """Preprocess FVCOM model output."""
+    return ds
+
+
+def preprocess_selfe(ds):
+    """Preprocess SELFE model output."""
+    return ds
+
+
+def preprocess_hycom(ds):
+    """Preprocess HYCOM model output for use with cf-xarray.
+
+    Also fixes any other known issues with model output.
+
+    Parameters
+    ----------
+    ds: xarray Dataset
+
+    Returns
+    -------
+    Same Dataset but with some metadata added and/or altered.
+    """
+
+    if "time" in ds:
+        ds["time"].attrs["axis"] = "T"
+
+    return ds
+
+
+def preprocess_pom(ds, interp_vertical: bool = True):
+    """Preprocess POM model output for use with cf-xarray.
+
+    Also fixes any other known issues with model output.
+
+    Parameters
+    ----------
+    ds : xr.Dataset
+        A dataset containing POM model output.
+
+    Returns
+    -------
+    xr.Dataset
+        Same Dataset but with some metadata added and/or altered.
+    """
+    # The longitude and latitude variables are not recognized as valid coordinates
+    if "longitude" not in ds.cf.coords:
+        if "longitude" not in ds.cf.standard_names:
+            raise ValueError("No variable describing longitude is available.")
+
+        if "latitude" not in ds.cf.standard_names:
+            raise ValueError("No variable describing latitude is available.")
+
+        ds = ds.cf.set_coords(["latitude", "longitude"])
+
+    # need to also make this a coordinate to add attributes
+    ds["nx"] = ("nx", np.arange(ds.sizes["nx"]), {"axis": "X"})
+    ds["ny"] = ("ny", np.arange(ds.sizes["ny"]), {"axis": "Y"})
+
+    # need to add coordinates to each data variable too
+    for var in ds.data_vars:
+        if ds[var].ndim == 3:
+            ds[var].encoding["coordinates"] = "time lat lon"
+        elif ds[var].ndim == 4:
+            ds[var].encoding["coordinates"] = "time depth lat lon"
+
+    if interp_vertical:
+        ds.cf.decode_vertical_coords(outnames={"sigma": "z"})
+
+        # fix attrs
+        for zname in ["z"]:
+            if zname in ds:
+                # coord inherits from one of the vars going into calculation
+                ds[zname].attrs = {}
+                ds[zname].attrs["positive"] = "up"
+                ds[zname].attrs["units"] = "m"
+                ds[zname] = order(ds[zname])
+
+    # keep sigma from showing up as "vertical" in cf-xarray
+    for sname in ["sigma"]:
+        if sname in ds:
+            del ds[sname].attrs["positive"]
+
+    return ds
+
+
+def preprocess_rtofs(ds):
+    """Preprocess RTOFS model output."""
+
+    raise NotImplementedError
+
+
+def preprocess(ds, model_type=None, kwargs=None):
+    """A preprocess function for reading in with xarray.
+
+    This tries to address known model shortcomings in a generic way so that
+    `cf-xarray` will work generally, including decoding vertical coordinates.
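+
+    A sketch of intended usage when opening model output with xarray, where
+    ``paths`` is a placeholder for your model output file path(s):
+
+    >>> ds = xr.open_mfdataset(paths, preprocess=preprocess)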
+    """
+
+    kwargs = kwargs or {}
+
+    # This is an internal attribute used by netCDF which xarray doesn't know or care about, but can
+    # be returned from THREDDS.
+    if "_NCProperties" in ds.attrs:
+        del ds.attrs["_NCProperties"]
+
+    # Preprocess for all models: if cf-xarray has not identified axes Z but has identified coordinate vertical
+    # and the vertical coordinate is 1D, add `axis="Z"` to its attributes so it will also be recognized as
+    # the Z axes.
+    if "vertical" in ds.cf.coordinates and "Z" not in ds.cf.axes:
+        if ds.cf["vertical"].ndim == 1 and len(ds.cf.coordinates["vertical"]) == 1:
+            key = ds.cf.coordinates["vertical"][0]
+            ds[key].attrs["axis"] = "Z"
+
+    preprocess_map = {
+        "ROMS": preprocess_roms,
+        "FVCOM": preprocess_fvcom,
+        "SELFE": preprocess_selfe,
+        "HYCOM": preprocess_hycom,
+        "POM": preprocess_pom,
+        "RTOFS": preprocess_rtofs,
+    }
+
+    if model_type is None:
+        model_type = guess_model_type(ds)
+
+    if model_type in preprocess_map:
+        return preprocess_map[model_type](ds, **kwargs)
+
+    return ds
diff --git a/extract_model/utils.py b/extract_model/utils.py
index 2d1dc5d..b731105 100644
--- a/extract_model/utils.py
+++ b/extract_model/utils.py
@@ -10,6 +10,8 @@
 import numpy as np
 import xarray as xr
 
+from sklearn.neighbors import BallTree
+
 from extract_model.grids.triangular_mesh import UnstructuredGridSubset
 from extract_model.model_type import ModelType
@@ -482,276 +484,125 @@ def order(da):
     )
 
 
-def preprocess_roms(ds, interp_vertical: bool = True):
-    """Preprocess ROMS model output for use with cf-xarray.
+def tree_query(
+    lon_coords: xr.DataArray,
+    lat_coords: xr.DataArray,
+    lons_to_find: np.array,
+    lats_to_find: np.array,
+    k: int = 3,
+) -> Tuple[np.array]:
+    """Set up and query BallTree for k nearest points.
 
-    Also fixes any other known issues with model output.
+    Uses haversine for the metric because we are dealing with lon/lat coordinates.
 
     Parameters
     ----------
-    ds: xarray Dataset
-
-    interp_vertical=True
+    lon_coords : xr.DataArray
+        Longitude coordinates of the grid you are searching for nearest points on.
+    lat_coords : xr.DataArray
+        Latitude coordinates of the grid you are searching for nearest points on.
+    lons_to_find : np.array
+        Longitudes of the points you are searching for nearest grid points to.
+    lats_to_find : np.array
+        Latitudes of the points you are searching for nearest grid points to.
+    k : int, optional
+        Number of nearest points to return, by default 3
 
     Returns
     -------
-    Same Dataset but with some metadata added and/or altered.
+    Tuple[np.array]
+        distances to, and 2D indices (iys, ixs) of, the k nearest grid points
+
+    Notes
+    -----
+    Reference: https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.BallTree.html
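+
+    Examples
+    --------
+    A sketch of finding the 4 nearest grid points to a single location, where
+    ``ds`` is assumed to be a Dataset with 2D longitude/latitude coordinates:
+
+    >>> distances, (iys, ixs) = tree_query(
+    ...     ds.cf["longitude"], ds.cf["latitude"], np.array([-95.0]), np.array([28.0]), k=4
+    ... )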
     """
-    # add axes attributes for dimensions
-    dims = [dim for dim in ds.dims if dim.startswith("s_")]
-    for dim in dims:
-        ds[dim].attrs["axis"] = "Z"
-
-    if "ocean_time" in ds.keys():
-        ds.ocean_time.attrs["axis"] = "T"
-        ds.ocean_time.attrs["standard_name"] = "time"
-    elif "time" in ds.keys():
-        ds.time.attrs["axis"] = "T"
-        ds.time.attrs["standard_name"] = "time"
-
-    dims = [dim for dim in ds.dims if dim.startswith("xi_")]
-    # need to also make this a coordinate to add attributes
-    for dim in dims:
-        ds[dim] = (dim, np.arange(ds.sizes[dim]), {"axis": "X"})
-
-    dims = [dim for dim in ds.dims if dim.startswith("eta_")]
-    for dim in dims:
-        ds[dim] = (dim, np.arange(ds.sizes[dim]), {"axis": "Y"})
-
-    # Fix standard_name for s_rho/s_w
-    if "Vtransform" in ds.data_vars and "s_rho" in ds.coords:
-        cond1 = (
-            ds["Vtransform"] == 1
-            and ds["s_rho"].attrs["standard_name"] == "ocean_s_coordinate"
-        )
-        cond2 = (
-            ds["Vtransform"] == 2
-            and ds["s_rho"].attrs["standard_name"] == "ocean_s_coordinate"
-        )
-        if cond1:
-            ds["s_rho"].attrs["standard_name"] = "ocean_s_coordinate_g1"
-        elif cond2:
-            ds["s_rho"].attrs["standard_name"] = "ocean_s_coordinate_g2"
-
-        cond1 = (
-            ds["Vtransform"] == 1
-            and ds["s_w"].attrs["standard_name"] == "ocean_s_coordinate"
-        )
-        cond2 = (
-            ds["Vtransform"] == 2
-            and ds["s_w"].attrs["standard_name"] == "ocean_s_coordinate"
-        )
-        if cond1:
-            ds["s_w"].attrs["standard_name"] = "ocean_s_coordinate_g1"
-        elif cond2:
-            ds["s_w"].attrs["standard_name"] = "ocean_s_coordinate_g2"
-
-    # calculate vertical coord
-    if interp_vertical:
-        name_dict = {}
-        if "s_rho" in ds.dims:
-            name_dict["s_rho"] = "z_rho"
-            if "positive" in ds.s_rho.attrs:
-                ds.s_rho.attrs.pop("positive")
-        if "s_w" in ds.dims:
-            name_dict["s_w"] = "z_w"
-            if "positive" in ds.s_w.attrs:
-                ds.s_w.attrs.pop("positive")
-        ds.cf.decode_vertical_coords(outnames=name_dict)
-
-        # fix attrs
-        for zname in ["z_rho", "z_w"]:  # name_dict.values():
-            if zname in ds:
-                ds[
-                    zname
-                ].attrs = (
-                    {}
-                )  # coord inherits from one of the vars going into calculation
-                ds[zname].attrs["positive"] = "up"
-                ds[zname].attrs["units"] = "m"
-                ds[zname] = order(ds[zname])
-
-        # replace s_rho with z_rho, etc, to make z_rho the vertical coord
-        for sname, zname in name_dict.items():
-            for var in ds.data_vars:
-                if ds[var].ndim == 4:
-                    if "coordinates" in ds[var].encoding:
-                        coords = ds[var].encoding["coordinates"]
-                        if sname in coords:  # replace if present
-                            coords = coords.replace(sname, zname)
-                        else:  # still add z_rho or z_w
-                            if zname in ds.coords and ds[zname].shape == ds[var].shape:
-                                coords += f" {zname}"
-                        ds[var].encoding["coordinates"] = coords
-
-    # # easier to remove "coordinates" attribute from any variables than add it to all
-    # for var in ds.data_vars:
-    #     if "coordinates" in ds[var].encoding:
-    #         del ds[var].encoding["coordinates"]
-
-    # # add attribute "coordinates" to all variables with at least 2 dimensions
-    # # and the dimensions have to be the regular types (time, Z, Y, X)
-    # for var in ds.data_vars:
-    #     if ds[var].ndim >= 2 and (len(set(ds[var].dims) - set([ds[var].cf[axes].name for axes in ds[var].cf.axes])) == 0):
-    #         coords = ['time', 'vertical', 'latitude', 'longitude']
-    #         var_names = [ds[var].cf[coord].name for coord in coords if coord in ds[var].cf.coords.keys()]
-    #         coord_str = " ".join(var_names)
-    #         ds[var].attrs["coordinates"] = coord_str
-
-    # Add standard_names for typical ROMS variables
-    # should this not overwrite standard name if it already exists?
-    var_map = {
-        "zeta": "sea_surface_elevation",
-        "salt": "sea_water_practical_salinity",
-        "temp": "sea_water_temperature",
-    }
-    for var_name, standard_name in var_map.items():
-        if var_name in ds.data_vars and "standard_name" not in ds[var_name].attrs:
-            ds[var_name].attrs["standard_name"] = standard_name
-
-    # Fix calendar if wrong
-    attrs = ds[ds.cf["T"].name].attrs
-    if ("calendar" in attrs) and (attrs["calendar"] == "gregorian_proleptic"):
-        attrs["calendar"] = "proleptic_gregorian"
-    ds[ds.cf["T"].name].attrs = attrs
-
-    return ds
-
-
-def preprocess_fvcom(ds):
-    """Preprocess FVCOM model output."""
-    return ds
-
-
-def preprocess_selfe(ds):
-    """Preprocess SELFE model output."""
-    return ds
-
-
-def preprocess_hycom(ds):
-    """Preprocess HYCOM model output for use with cf-xarray.
-
-    Also fixes any other known issues with model output.
+    # create tree
+    coords = [lon_coords, lat_coords]
+    X = np.stack([np.ravel(c) for c in coords]).T
+    tree = BallTree(np.deg2rad(X), metric="haversine")
 
-    Parameters
-    ----------
-    ds: xarray Dataset
+    # set up coordinates we want to search for
+    coords_to_find = [lons_to_find, lats_to_find]
+    X_to_find = np.stack([np.ravel(c) for c in coords_to_find]).T
 
-    Returns
-    -------
-    Same Dataset but with some metadata added and/or altered.
-    """
+    # query tree
+    distances, inds = tree.query(np.deg2rad(X_to_find), k=k)
 
-    if "time" in ds:
-        ds["time"].attrs["axis"] = "T"
+    # convert flat indices to 2D indices
+    iys, ixs = np.unravel_index(inds, lon_coords.shape)
 
-    return ds
+    return distances, (iys, ixs)
 
 
-def preprocess_pom(ds, interp_vertical: bool = True):
-    """Preprocess POM model output for use with cf-xarray.
-
-    Also fixes any other known issues with model output.
+def calc_barycentric(
+    x: np.array, y: np.array, xs: np.array, ys: np.array
+) -> xr.DataArray:
+    """Calculate barycentric weights for npts points.
 
     Parameters
     ----------
-    ds : xr.Dataset
-        A dataset containing data described from POM output.
+    x
+        npts x 1 vector of x locations, can be in lon or projection coordinates.
+    y
+        npts x 1 vector of y locations, can be in lat or projection coordinates.
+    xs
+        npts x 3 array of triangle x vertices with which to calculate the barycentric weights for each of npts
+    ys
+        npts x 3 array of triangle y vertices with which to calculate the barycentric weights for each of npts
 
     Returns
     -------
-    xr.Dataset
-        Same Dataset but with some metadata added and/or altered.
+    xr.DataArray
+        Lambda, npts x 3, containing for each of npts the 3 barycentric weights to use for interpolation.
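+
+    Notes
+    -----
+    For each point, the weights satisfy x = lam[0] * xs[0] + lam[1] * xs[1] + lam[2] * xs[2]
+    (and likewise for y) with lam.sum() == 1, so values are interpolated linearly
+    within the triangle; a point sitting on a vertex gets weight 1 there and 0 elsewhere.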
""" - # The longitude and latitude variables are not recognized as valid coordinates - if "longitude" not in ds.cf.coords: - if "longitude" not in ds.cf.standard_names: - raise ValueError("No variable describing longitude is available.") - - if "latitude" not in ds.cf.standard_names: - raise ValueError("No variable describing latitude is available.") - - ds = ds.cf.set_coords(["latitude", "longitude"]) - - # need to also make this a coordinate to add attributes - ds["nx"] = ("nx", np.arange(ds.sizes["nx"]), {"axis": "X"}) - ds["ny"] = ("ny", np.arange(ds.sizes["ny"]), {"axis": "Y"}) - - # need to add coordinates to each data variable too - for var in ds.data_vars: - if ds[var].ndim == 3: - ds[var].encoding["coordinates"] = "time lat lon" - elif ds[var].ndim == 4: - ds[var].encoding["coordinates"] = "time depth lat lon" - - if interp_vertical: - ds.cf.decode_vertical_coords(outnames={"sigma": "z"}) - - # fix attrs - for zname in ["z"]: # name_dict.values(): - if zname in ds: - ds[ - zname - ].attrs = ( - {} - ) # coord inherits from one of the vars going into calculation - ds[zname].attrs["positive"] = "up" - ds[zname].attrs["units"] = "m" - ds[zname] = order(ds[zname]) + # barycentric weights + # npts x 1 (vectors) + L1 = ( + (ys[:, 1] - ys[:, 2]) * (x[:] - xs[:, 2]) + + (xs[:, 2] - xs[:, 1]) * (y[:] - ys[:, 2]) + ) / ( + (ys[:, 1] - ys[:, 2]) * (xs[:, 0] - xs[:, 2]) + + (xs[:, 2] - xs[:, 1]) * (ys[:, 0] - ys[:, 2]) + ) + L2 = ( + (ys[:, 2] - ys[:, 0]) * (x[:] - xs[:, 2]) + + (xs[:, 0] - xs[:, 2]) * (y[:] - ys[:, 2]) + ) / ( + (ys[:, 1] - ys[:, 2]) * (xs[:, 0] - xs[:, 2]) + + (xs[:, 2] - xs[:, 1]) * (ys[:, 0] - ys[:, 2]) + ) + L3 = 1 - L1 - L2 - # keep sigma from showing up as "vertical" in cf-xarray - for sname in ["sigma"]: # name_dict.values(): - if sname in ds: - del ds[sname].attrs["positive"] + lam = xr.DataArray(dims=("npts", "triangle"), data=np.vstack((L1, L2, L3)).T) - return ds + return lam -def preprocess_rtofs(ds): - """Preprocess RTOFS model output.""" +def interp_with_barycentric(da, ixs, iys, lam): + vector = da.cf.isel( + X=xr.DataArray(ixs, dims=("npts", "triangle")), + Y=xr.DataArray(iys, dims=("npts", "triangle")), + ) + with xr.set_options(keep_attrs=True): + da = xr.dot(vector, lam, dims=("triangle")) - raise NotImplementedError + # get z coordinates to go with interpolated output if not available + if "vertical" in vector.cf.coords: + zkey = vector.cf["vertical"].name + # only need to interpolate z coordinates if they are not 1D + if vector[zkey].ndim > 1: + da_vert = xr.dot(vector[zkey], lam, dims=("triangle")) -def preprocess(ds, model_type=None, kwargs=None): - """A preprocess function for reading in with xarray. + # add vertical coords into da + da = da.assign_coords({zkey: da_vert}) - This tries to address known model shortcomings in a generic way so that - `cf-xarray` will work generally, including decoding vertical coordinates. - """ + # add "X" axis to npts + da["npts"] = ("npts", da.npts.values, {"axis": "X"}) - kwargs = kwargs or {} - - # This is an internal attribute used by netCDF which xarray doesn't know or care about, but can - # be returned from THREDDS. - if "_NCProperties" in ds.attrs: - del ds.attrs["_NCProperties"] - - # Preprocess for all models: if cf-xarray has not identifed axes Z but has identified coordinate vertical - # and the vertical coordinate is 1D, add `axis="Z"` to its attributes so it will also be recognized as - # the Z axes. 
-    if "vertical" in ds.cf.coordinates and "Z" not in ds.cf.axes:
-        if ds.cf["vertical"].ndim == 1 and len(ds.cf.coordinates["vertical"]) == 1:
-            key = ds.cf.coordinates["vertical"][0]
-            ds[key].attrs["axis"] = "Z"
-
-    preprocess_map = {
-        "ROMS": preprocess_roms,
-        "FVCOM": preprocess_fvcom,
-        "SELFE": preprocess_selfe,
-        "HYCOM": preprocess_hycom,
-        "POM": preprocess_pom,
-        "RTOFS": preprocess_rtofs,
-    }
-
-    if model_type is None:
-        model_type = guess_model_type(ds)
-
-    if model_type in preprocess_map:
-        return preprocess_map[model_type](ds, **kwargs)
-
-    return ds
+    return da, vector.coords
 
 
 def guess_model_type(ds: xr.Dataset) -> Optional[ModelType]:
diff --git a/tests/grids/test_triangular_mesh.py b/tests/grids/test_triangular_mesh.py
index 7c1df33..f7adf18 100644
--- a/tests/grids/test_triangular_mesh.py
+++ b/tests/grids/test_triangular_mesh.py
@@ -7,7 +7,7 @@
 import pytest
 import xarray as xr
 
-from extract_model import utils
+from extract_model import preprocessing, utils
 from extract_model.grids.triangular_mesh import UnstructuredGridSubset
@@ -222,7 +222,7 @@ def test_fvcom_preload(real_fvcom):
 
 def test_fvcom_preprocess(real_fvcom):
-    ds = utils.preprocess(real_fvcom)
+    ds = preprocessing.preprocess(real_fvcom)
     assert ds is not None
@@ -310,7 +310,7 @@ def test_selfe_filter(selfe_data):
 
 def test_selfe_preprocess(selfe_data):
-    ds = utils.preprocess(selfe_data)
+    ds = preprocessing.preprocess(selfe_data)
     assert ds is not None
diff --git a/tests/test_accessor.py b/tests/test_accessor.py
index eb2f818..e6ecf06 100644
--- a/tests/test_accessor.py
+++ b/tests/test_accessor.py
@@ -36,16 +36,19 @@ def test_2dsel():
     inputs = {
         da.cf["longitude"].name: lon_comp,
         da.cf["latitude"].name: lat_comp,
-        "distances_name": "distance",
     }
 
-    da_sel2d = da.em.sel2d(**inputs)
+    da_sel2d, kwargs_out_sel2d_acc_check = da.em.sel2d(**inputs, return_info=True)
     da_check = da.cf.isel(X=i, Y=j)
+    da_sel2d_check = da_sel2d[varname]
 
     # checks that the resultant model output is the same
-    assert np.allclose(da_sel2d[varname].squeeze(), da_check)
+    assert np.allclose(da_sel2d_check.squeeze(), da_check)
 
-    da_test = da.em.sel2dcf(
-        longitude=lon_comp, latitude=lat_comp, distances_name="distance"
+    da_test, kwargs_out = da.em.sel2dcf(
+        longitude=lon_comp,
+        latitude=lat_comp,
+        return_info=True,
     )
 
     assert np.allclose(da_sel2d[varname], da_test[varname])
-    assert np.allclose(da_sel2d["distance"], da_test["distance"])
+    assert np.allclose(kwargs_out_sel2d_acc_check["distances"], kwargs_out["distances"])
diff --git a/tests/test_em.py b/tests/test_em.py
index b6b5fd2..67cdd1b 100644
--- a/tests/test_em.py
+++ b/tests/test_em.py
@@ -27,12 +27,12 @@ def test_T_interp_no_xesmf():
     url = Path(__file__).parent / "data/test_roms.nc"
     ds = xr.open_dataset(url)
 
-    da_out, _ = em.select(da=ds["zeta"], T=0.5)
+    da_out = em.select(da=ds["zeta"], T=0.5)
     assert np.allclose(da_out[0, 0], -0.12584045)
 
     XESMF_AVAILABLE = em.extract_model.XESMF_AVAILABLE
     em.extract_model.XESMF_AVAILABLE = False
-    da_out, _ = em.select(da=ds["zeta"], T=0.5)
+    da_out = em.select(da=ds["zeta"], T=0.5)
     assert np.allclose(da_out[0, 0], -0.12584045)
     em.extract_model.XESMF_AVAILABLE = XESMF_AVAILABLE
@@ -42,7 +42,7 @@ def test_Z_interp():
     url = Path(__file__).parent / "data/test_hycom.nc"
     ds = xr.open_dataset(url)
 
-    da_out, _ = em.select(da=ds["water_u"], Z=1.0)
+    da_out = em.select(da=ds["water_u"], Z=1.0, vertical_interp=True)
     assert np.allclose(da_out[-1, -1], -0.1365)
@@ -65,7 +67,9 @@ def test_hor_interp_no_xesmf():
     XESMF_AVAILABLE = em.extract_model.XESMF_AVAILABLE
     em.extract_model.XESMF_AVAILABLE = False
     with pytest.raises(ModuleNotFoundError):
-        em.select(da, longitude=longitude, latitude=latitude, T=0.5)
+        em.select(
+            da, longitude=longitude, latitude=latitude, T=0.5, horizontal_interp=True
+        )
     em.extract_model.XESMF_AVAILABLE = XESMF_AVAILABLE
@@ -95,14 +97,14 @@ def test_sel2d(model):
     inputs = {
         da.cf["longitude"].name: lon_comp,
         da.cf["latitude"].name: lat_comp,
-        "distances_name": "distance",
+        "return_info": True,
     }
 
-    da_sel2d = em.sel2d(da, **inputs)
+    da_sel2d, kwargs_out = em.sel2d(da, **inputs)
     da_check = da.cf.isel(X=i, Y=j)
 
     assert np.allclose(da_sel2d[varname].squeeze(), da_check)
     # 6371 is radius of earth in km
-    assert np.allclose(da_sel2d["distance"], np.deg2rad(dlat) * 6371)
+    assert np.allclose(kwargs_out["distances"], np.deg2rad(dlat) * 6371)
 
 
 @pytest.mark.parametrize("model", models, ids=lambda x: x["name"])
@@ -356,24 +358,22 @@ def test_sel2d_simple_2D():
     )
 
     # check distance when ran with exact grid point
-    ds_out = em.sel2d(ds, lon=0, lat=4, distances_name="distance")
-    assert np.allclose(ds_out["distance"], 0)
+    ds_out, kwargs_out = em.sel2d(ds, lon=0, lat=4, return_info=True)
+    assert np.allclose(kwargs_out["distances"], 0)
 
     # check that alternative function call returns exact results
-    ds_outcf = em.sel2dcf(ds, longitude=0, latitude=4, distances_name="distance")
-    assert ds_out == ds_outcf
+    ds_outcf = em.sel2dcf(ds, longitude=0, latitude=4)
+    assert ds_out.coords == ds_outcf.coords
 
     # use mask leaving one valid point to check behavior with mask
     mask = (ds.cf["longitude"] == 3).astype(int)
-    ds_out = em.sel2d(ds, lon=0, lat=4, mask=mask, distances_name="distance")
+    ds_out = em.sel2d(ds, lon=0, lat=4, mask=mask)
     assert ds_out.lon == 3
     assert ds_out.lat == 7
 
-    ds_outcf = em.sel2dcf(
-        ds, longitude=0, latitude=4, mask=mask, distances_name="distance"
-    )
-    assert ds_out == ds_outcf
+    ds_outcf = em.sel2dcf(ds, longitude=0, latitude=4, mask=mask)
+    assert ds_out.coords == ds_outcf.coords
 
-    # if distance_name=None, no distance returned
-    ds_out = em.sel2d(ds, lon=0, lat=4, distances_name=None)
-    assert "distance" not in ds_out.variables
+    # without return_info, no distances are returned in the object
+    ds_out = em.sel2d(ds, lon=0, lat=4)
+    assert "distances" not in ds_out.variables
diff --git a/tests/test_sel2d.py b/tests/test_sel2d.py
index 5f3fb50..cb4402e 100644
--- a/tests/test_sel2d.py
+++ b/tests/test_sel2d.py
@@ -36,15 +36,15 @@ def test_lon_lat_types():
     # Floats
     da_test = em.sel2d(da, lon_rho=lon, lat_rho=lat).squeeze()
-    assert np.allclose(da_check, da_test)
+    assert np.allclose(da_check, da_test.to_array())
 
     # List
     da_test = em.sel2d(da, lon_rho=[lon], lat_rho=[lat]).squeeze()
-    assert np.allclose(da_check, da_test)
+    assert np.allclose(da_check, da_test.to_array())
 
     # Array
     da_test = em.sel2d(da, lon_rho=np.array([lon]), lat_rho=np.array([lat])).squeeze()
-    assert np.allclose(da_check, da_test)
+    assert np.allclose(da_check, da_test.to_array())
 
 
 def test_2D():
@@ -94,7 +94,7 @@ def test_ll_name_reversal():
     da1 = em.sel2d(da, lon_rho=lon, lat_rho=lat).squeeze()
     da2 = em.sel2d(da, lat_rho=lat, lon_rho=lon).squeeze()
 
-    assert np.allclose(da1, da2)
+    assert np.allclose(da1.to_array(), da2.to_array())
 
 
 def test_sel_time():
@@ -105,7 +105,7 @@
     da_test = em.sel2d(da, lon_rho=lon, lat_rho=lat, ocean_time=0)
 
-    assert np.allclose(da_check, da_test)
+    assert np.allclose(da_check, da_test.to_array())
 
     # Won't run in different input order
     with pytest.raises(ValueError):
@@ -117,14 +117,14 @@ def test_cf_versions():
     da_check = em.sel2d(da, lon_rho=lon, lat_rho=lat)
     da_test = em.sel2dcf(da, longitude=lon, latitude=lat)
 
-    assert np.allclose(da_check, da_test)
+    assert np.allclose(da_check.to_array(), da_test.to_array())
 
     da_test = em.sel2dcf(da, latitude=lat, longitude=lon)
 
-    assert np.allclose(da_check, da_test)
+    assert np.allclose(da_check.to_array(), da_test.to_array())
 
     da_check = em.sel2d(da, lon_rho=lon, lat_rho=lat, ocean_time=0)
     da_test = em.sel2dcf(da, latitude=lat, longitude=lon, T=0)
 
-    assert np.allclose(da_check, da_test)
+    assert np.allclose(da_check.to_array(), da_test.to_array())
 
     da_test = em.sel2dcf(da, T=0, longitude=lon, latitude=lat)
 
-    assert np.allclose(da_check, da_test)
+    assert np.allclose(da_check.to_array(), da_test.to_array())