Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR]: Update Regrid2 missing and fill value behaviors to align with CDAT and add unmapped_to_nan arg for output data #613

Merged
merged 16 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 11 additions & 7 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,12 +496,15 @@ def test_regrid_input_mask(self):

output_data = regridder.horizontal("ts", self.coarse_2d_ds)

# replace nan with 1e20 as np.nan != np.nan
output_data = output_data.fillna(1e20)

expected_output = np.array(
[
[0.0] * 4,
[0.70710677] * 4,
[0.70710677] * 4,
[0.0] * 4,
[1e20] * 4,
[1.0] * 4,
[1.0] * 4,
[1e20] * 4,
],
dtype=np.float32,
)
Expand Down Expand Up @@ -690,7 +693,7 @@ def test_regrid(self):
assert "time_bnds" in output

@pytest.mark.parametrize(
"name,value,attr_name",
"name,value,_",
[
("periodic", True, "_periodic"),
("extrap_method", "inverse_dist", "_extrap_method"),
Expand All @@ -700,14 +703,15 @@ def test_regrid(self):
("ignore_degenerate", False, "_ignore_degenerate"),
],
)
def test_flags(self, name, value, attr_name):
def test_flags(self, name, value, _):
ds = self.ds.copy()

options = {name: value}

regridder = xesmf.XESMFRegridder(ds, self.new_grid, "bilinear", **options)

assert getattr(regridder, attr_name) == value
assert name in regridder._extra_options
assert regridder._extra_options[name] == value

def test_no_variable(self):
ds = self.ds.copy()
Expand Down
89 changes: 66 additions & 23 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, List, Tuple
from typing import Any, List, Optional, Tuple

import numpy as np
import xarray as xr
Expand Down Expand Up @@ -66,20 +66,30 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
dst_lat_bnds = _get_bounds_ensure_dtype(self._output_grid, "Y")
dst_lon_bnds = _get_bounds_ensure_dtype(self._output_grid, "X")

src_mask = self._input_grid.get("mask", None)
src_mask_da = self._input_grid.get("mask", None)

# apply source mask to input data
if src_mask is not None:
masked_value = self._input_grid.attrs.get("_FillValue", None)
# DataArray to np.ndarray, handle error when None
try:
src_mask = src_mask_da.values # type: ignore
except AttributeError:
src_mask = None

nan_replace = input_data_var.encoding.get("_FillValue", None)

if masked_value is None:
masked_value = self._input_grid.attrs.get("missing_value", 0.0)
if nan_replace is None:
nan_replace = input_data_var.encoding.get("missing_value", 1e20)

# Xarray defaults to masking with np.nan, CDAT masked with _FillValue or missing_value which defaults to 1e20
input_data_var = input_data_var.where(src_mask != 0.0, masked_value)
# exclude alternative of NaN values if there are any
input_data_var = input_data_var.where(input_data_var != nan_replace)

# horizontal regrid
output_data = _regrid(
input_data_var, src_lat_bnds, src_lon_bnds, dst_lat_bnds, dst_lon_bnds
input_data_var,
src_lat_bnds,
src_lon_bnds,
dst_lat_bnds,
dst_lon_bnds,
src_mask,
)

output_ds = _build_dataset(
Expand All @@ -101,7 +111,12 @@ def _regrid(
src_lon_bnds: np.ndarray,
dst_lat_bnds: np.ndarray,
dst_lon_bnds: np.ndarray,
src_mask: Optional[np.ndarray],
omitted=None,
) -> np.ndarray:
if omitted is None:
omitted = np.nan

lat_mapping, lat_weights = _map_latitude(src_lat_bnds, dst_lat_bnds)
lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds)

Expand All @@ -114,6 +129,11 @@ def _regrid(
y_length = len(lat_mapping)
x_length = len(lon_mapping)

if src_mask is None:
input_data_shape = input_data.shape

src_mask = np.ones((input_data_shape[y_index], input_data_shape[x_index]))

other_dims = {
x: y for x, y in input_data_var.sizes.items() if x not in (y_name, x_name)
}
Expand All @@ -127,13 +147,20 @@ def _regrid(

# TODO: need to optimize further, investigate using ufuncs and dask arrays
# TODO: how common is lon by lat data? may need to reshape
# import pdb; pdb.set_trace()
for y in range(y_length):
tomvothecoder marked this conversation as resolved.
Show resolved Hide resolved
y_seg = np.take(input_data, lat_mapping[y], axis=y_index)
y_mask_seg = np.take(src_mask, lat_mapping[y], axis=0)

for x in range(x_length):
x_seg = np.take(y_seg, lon_mapping[x], axis=x_index, mode="wrap")
x_mask_seg = np.take(y_mask_seg, lon_mapping[x], axis=1, mode="wrap")

cell_weights = np.multiply(
np.dot(lat_weights[y], lon_weights[x]), x_mask_seg
)

cell_weight = np.dot(lat_weights[y], lon_weights[x])
cell_weight = np.sum(cell_weights)

output_seg_index = y * x_length + x

Expand All @@ -144,23 +171,26 @@ def _regrid(
if is_2d:
output_data[output_seg_index] = np.divide(
np.sum(
np.multiply(x_seg, cell_weight),
np.multiply(x_seg, cell_weights),
axis=(y_index, x_index),
),
np.sum(cell_weight),
cell_weight,
)
else:
output_seg = output_data[output_seg_index]

np.divide(
np.sum(
np.multiply(x_seg, cell_weight),
np.multiply(x_seg, cell_weights),
axis=(y_index, x_index),
),
np.sum(cell_weight),
cell_weight,
out=output_seg,
)

if cell_weight <= 0.0:
output_data[output_seg_index] = omitted

output_data_shape = [y_length, x_length] + other_sizes

output_data = output_data.reshape(output_data_shape)
Expand Down Expand Up @@ -208,7 +238,9 @@ def _build_dataset(
return output_ds


def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
def _map_latitude(
src: np.ndarray, dst: np.ndarray
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""
Map source to destination latitude.

Expand All @@ -230,7 +262,7 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:

Returns
-------
Tuple[List, List]
Tuple[List[np.ndarray], List[np.ndarray]]
A tuple of cell mappings and cell weights.
"""
src_south, src_north = _extract_bounds(src)
Expand All @@ -255,14 +287,25 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
]

# convert latitude to cell weight (difference of height above/below equator)
weights = [
(np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))).reshape((-1, 1))
for x, y in bounds
]
weights = _get_latitude_weights(bounds)

return mapping, weights


def _get_latitude_weights(
    bounds: List[Tuple[np.ndarray, np.ndarray]]
) -> List[np.ndarray]:
    """
    Convert pairs of latitude bounds to cell weights.

    The weight of a cell is the difference between the sines of its two
    bounding latitudes (the difference of height above/below the equator).

    Parameters
    ----------
    bounds : List[Tuple[np.ndarray, np.ndarray]]
        Pairs of latitude bounds in degrees. Each weight is computed as
        ``sin(first) - sin(second)``, so the first element of a pair is
        presumably the northern bound and the second the southern bound
        (TODO: confirm against the caller).

    Returns
    -------
    List[np.ndarray]
        One array of shape ``(n, 1)`` per pair of bounds, in the same order
        as ``bounds``. An empty ``bounds`` yields an empty list.
    """
    weights = []

    for first, second in bounds:
        cell_weight = np.sin(np.deg2rad(first)) - np.sin(np.deg2rad(second))

        # Store as a column vector; the weights are later combined with the
        # longitude weights through ``np.dot``.
        cell_weight = cell_weight.reshape((-1, 1))

        weights.append(cell_weight)

    return weights


def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
Comment on lines +312 to +325
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I extracted this section of code into its own private function and used an explicit for loop, which I find more readable than the list comprehension (IMO).

Copy link
Collaborator

@tomvothecoder tomvothecoder Mar 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, I don't think Dask arrays support `np.deg2rad` and/or `np.sin` with `np.nan` values, which results in the `ValueError: cannot convert float NaN to integer` reported in #615.

"""
Map source to destination longitude.
Expand Down Expand Up @@ -347,12 +390,12 @@ def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
Parameters
----------
bounds : np.ndarray
Dataset containing axis with bounds.
A numpy array of bounds values.

Returns
-------
Tuple[np.ndarray, np.ndarray]
A tuple containing the lower and upper bounds for the axis.
A tuple containing the lower and upper bounds for the axis.
"""
if bounds[0, 0] < bounds[0, 1]:
lower = bounds[:, 0]
Expand Down
24 changes: 14 additions & 10 deletions xcdat/regridder/xesmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
extrap_dist_exponent: Optional[float] = None,
extrap_num_src_pnts: Optional[int] = None,
ignore_degenerate: bool = True,
unmapped_to_nan: bool = True,
**options: Any,
):
"""Extension of ``xESMF`` regridder.
Expand Down Expand Up @@ -74,6 +75,8 @@ def __init__(

This only applies to "conservative" and "conservative_normed"
regridding methods.
unmapped_to_nan : bool
Sets values of unmapped points to `np.nan` instead of 0 (ESMF default).
**options : Any
Additional arguments passed to the underlying ``xesmf.XESMFRegridder``
constructor.
Expand Down Expand Up @@ -126,11 +129,17 @@ def __init__(
)

self._method = method
self._periodic = periodic
self._extrap_method = extrap_method
self._extrap_dist_exponent = extrap_dist_exponent
self._extrap_num_src_pnts = extrap_num_src_pnts
self._ignore_degenerate = ignore_degenerate

# Re-pack xesmf arguments, broken out for validation/documentation
options.update(
periodic=periodic,
extrap_method=extrap_method,
extrap_dist_exponent=extrap_dist_exponent,
extrap_num_src_pnts=extrap_num_src_pnts,
ignore_degenerate=ignore_degenerate,
unmapped_to_nan=unmapped_to_nan,
)

self._extra_options = options

def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
Expand All @@ -150,11 +159,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
self._input_grid,
self._output_grid,
method=self._method,
periodic=self._periodic,
extrap_method=self._extrap_method,
extrap_dist_exponent=self._extrap_dist_exponent,
extrap_num_src_pnts=self._extrap_num_src_pnts,
ignore_degenerate=self._ignore_degenerate,
**self._extra_options,
)

Expand Down
Loading