-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[PR]: Update Regrid2 missing and fill value behaviors to align with CDAT and add unmapped_to_nan
arg for output data
#613
Changes from 3 commits
715b920
a8e7661
85d1c7c
7e900a0
3743e97
5aeee56
7af50a0
e193c1d
dfc3ebc
646a9a0
6719f03
56f18cb
0b01f79
647a786
79a4cd1
fc17e7a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
|
||
import numpy as np | ||
import xarray as xr | ||
from dask.array.core import Array | ||
|
||
from xcdat.axis import get_dim_keys | ||
from xcdat.regridder.base import BaseRegridder, _preserve_bounds | ||
|
@@ -78,6 +79,13 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: | |
# Xarray defaults to masking with np.nan, CDAT masked with _FillValue or missing_value which defaults to 1e20 | ||
input_data_var = input_data_var.where(src_mask != 0.0, masked_value) | ||
|
||
nan_replace = input_data_var.encoding.get("_FillValue", None) | ||
|
||
if nan_replace is None: | ||
nan_replace = input_data_var.encoding.get("missing_value", 1e20) | ||
|
||
input_data_var = input_data_var.fillna(nan_replace) | ||
|
||
output_data = _regrid( | ||
input_data_var, src_lat_bnds, src_lon_bnds, dst_lat_bnds, dst_lon_bnds | ||
) | ||
|
@@ -106,7 +114,7 @@ def _regrid( | |
lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds) | ||
|
||
# convert to pure numpy | ||
input_data = input_data_var.astype(np.float32).data | ||
input_data = input_data_var.astype(np.float32).values | ||
|
||
y_name, y_index = _get_dimension(input_data_var, "Y") | ||
x_name, x_index = _get_dimension(input_data_var, "X") | ||
|
@@ -208,7 +216,9 @@ def _build_dataset( | |
return output_ds | ||
|
||
|
||
def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: | ||
def _map_latitude( | ||
src: np.ndarray, dst: np.ndarray | ||
) -> Tuple[List[np.ndarray], List[np.ndarray]]: | ||
""" | ||
Map source to destination latitude. | ||
|
||
|
@@ -230,7 +240,7 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: | |
|
||
Returns | ||
------- | ||
Tuple[List, List] | ||
Tuple[List[np.ndarray], List[np.ndarray]] | ||
A tuple of cell mappings and cell weights. | ||
""" | ||
src_south, src_north = _extract_bounds(src) | ||
|
@@ -255,14 +265,25 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: | |
] | ||
|
||
# convert latitude to cell weight (difference of height above/below equator) | ||
weights = [ | ||
(np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))).reshape((-1, 1)) | ||
for x, y in bounds | ||
] | ||
weights = _get_latitude_weights(bounds) | ||
|
||
return mapping, weights | ||
|
||
|
||
def _get_latitude_weights( | ||
bounds: List[Tuple[np.ndarray, np.ndarray]] | ||
) -> List[np.ndarray]: | ||
weights = [] | ||
|
||
for x, y in bounds: | ||
cell_weight = np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y)) | ||
cell_weight = cell_weight.reshape((-1, 1)) | ||
|
||
weights.append(cell_weight) | ||
|
||
return weights | ||
|
||
|
||
Comment on lines
+312
to
+325
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I extracted this section of code into its own private function and used an explicit for loop to make it more readable compared to list comprehension (IMO). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also I don't think Dask Arrays support |
||
def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: | ||
""" | ||
Map source to destination longitude. | ||
|
@@ -340,19 +361,19 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: | |
return mapping, weights | ||
|
||
|
||
def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: | ||
def _extract_bounds(bounds: np.ndarray | Array) -> Tuple[np.ndarray, np.ndarray]: | ||
""" | ||
Extract lower and upper bounds from an axis. | ||
|
||
Parameters | ||
---------- | ||
bounds : np.ndarray | ||
Dataset containing axis with bounds. | ||
bounds : np.ndarray | dask.core.array.Array | ||
A numpy array or dask array of bounds values. | ||
|
||
Returns | ||
------- | ||
Tuple[np.ndarray, np.ndarray] | ||
A tuple containing the lower and upper bounds for the axis. | ||
A tuple containing the lower and upper bounds for the axis. | ||
""" | ||
if bounds[0, 0] < bounds[0, 1]: | ||
lower = bounds[:, 0] | ||
|
@@ -361,6 +382,15 @@ def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: | |
lower = bounds[:, 1] | ||
upper = bounds[:, 0] | ||
|
||
# Make sure to convert the bounds to numpy array beforehand. | ||
# Otherwise the error `ValueError: cannot convert float NaN to integer` | ||
# is raised when calculating cell weights with `_map_longitude` when | ||
# calling `np.sin(np.deg2rad(x))`. | ||
if isinstance(lower, Array): | ||
lower = lower.compute() | ||
if isinstance(upper, Array): | ||
upper = upper.compute() | ||
|
||
tomvothecoder marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return lower.astype(np.float32), upper.astype(np.float32) | ||
|
||
|
||
|
@@ -498,4 +528,4 @@ def _get_bounds_ensure_dtype(ds, axis): | |
if bounds.dtype != np.float32: | ||
bounds = bounds.astype(np.float32) | ||
|
||
return bounds.data | ||
return bounds.values | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make sure to use |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Make sure to use
.values
to convert Dask Arrays to.ndarray
. Otherwise we get aTypeError: take() got an unexpected keyword argument 'mode'
downstream withnp.take()
.