Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PR]: Update Regrid2 missing and fill value behaviors to align with CDAT and add unmapped_to_nan arg for output data #613

Merged
merged 16 commits into from
Apr 10, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ def test_regrid(self):
assert "time_bnds" in output

@pytest.mark.parametrize(
"name,value,attr_name",
"name,value,_",
[
("periodic", True, "_periodic"),
("extrap_method", "inverse_dist", "_extrap_method"),
Expand All @@ -700,14 +700,15 @@ def test_regrid(self):
("ignore_degenerate", False, "_ignore_degenerate"),
],
)
def test_flags(self, name, value, attr_name):
def test_flags(self, name, value, _):
ds = self.ds.copy()

options = {name: value}

regridder = xesmf.XESMFRegridder(ds, self.new_grid, "bilinear", **options)

assert getattr(regridder, attr_name) == value
assert name in regridder._extra_options
assert regridder._extra_options[name] == value

def test_no_variable(self):
ds = self.ds.copy()
Expand Down
54 changes: 42 additions & 12 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import numpy as np
import xarray as xr
from dask.array.core import Array

from xcdat.axis import get_dim_keys
from xcdat.regridder.base import BaseRegridder, _preserve_bounds
Expand Down Expand Up @@ -78,6 +79,13 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
# Xarray defaults to masking with np.nan, CDAT masked with _FillValue or missing_value which defaults to 1e20
input_data_var = input_data_var.where(src_mask != 0.0, masked_value)

nan_replace = input_data_var.encoding.get("_FillValue", None)

if nan_replace is None:
nan_replace = input_data_var.encoding.get("missing_value", 1e20)

input_data_var = input_data_var.fillna(nan_replace)

output_data = _regrid(
input_data_var, src_lat_bnds, src_lon_bnds, dst_lat_bnds, dst_lon_bnds
)
Expand Down Expand Up @@ -106,7 +114,7 @@ def _regrid(
lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds)

# convert to pure numpy
input_data = input_data_var.astype(np.float32).data
input_data = input_data_var.astype(np.float32).values
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure to use .values to convert Dask Arrays to .ndarray. Otherwise we get a TypeError: take() got an unexpected keyword argument 'mode' downstream with np.take().


y_name, y_index = _get_dimension(input_data_var, "Y")
x_name, x_index = _get_dimension(input_data_var, "X")
Expand Down Expand Up @@ -208,7 +216,9 @@ def _build_dataset(
return output_ds


def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
def _map_latitude(
src: np.ndarray, dst: np.ndarray
) -> Tuple[List[np.ndarray], List[np.ndarray]]:
"""
Map source to destination latitude.

Expand All @@ -230,7 +240,7 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:

Returns
-------
Tuple[List, List]
Tuple[List[np.ndarray], List[np.ndarray]]
A tuple of cell mappings and cell weights.
"""
src_south, src_north = _extract_bounds(src)
Expand All @@ -255,14 +265,25 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
]

# convert latitude to cell weight (difference of height above/below equator)
weights = [
(np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))).reshape((-1, 1))
for x, y in bounds
]
weights = _get_latitude_weights(bounds)

return mapping, weights


def _get_latitude_weights(
bounds: List[Tuple[np.ndarray, np.ndarray]]
) -> List[np.ndarray]:
weights = []

for x, y in bounds:
cell_weight = np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))
cell_weight = cell_weight.reshape((-1, 1))

weights.append(cell_weight)

return weights


Comment on lines +312 to +325
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I extracted this section of code into its own private function and used an explicit for loop to make it more readable compared to list comprehension (IMO).

Copy link
Collaborator

@tomvothecoder tomvothecoder Mar 28, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also I don't think Dask Arrays support np.deg2rad and/or np.sin with np.nan values, resulting in ValueError: cannot convert float NaN to integer in #615.

def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
"""
Map source to destination longitude.
Expand Down Expand Up @@ -340,19 +361,19 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
return mapping, weights


def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
def _extract_bounds(bounds: np.ndarray | Array) -> Tuple[np.ndarray, np.ndarray]:
"""
Extract lower and upper bounds from an axis.

Parameters
----------
bounds : np.ndarray
Dataset containing axis with bounds.
bounds : np.ndarray | dask.core.array.Array
A numpy array or dask array of bounds values.

Returns
-------
Tuple[np.ndarray, np.ndarray]
A tuple containing the lower and upper bounds for the axis.
A tuple containing the lower and upper bounds for the axis.
"""
if bounds[0, 0] < bounds[0, 1]:
lower = bounds[:, 0]
Expand All @@ -361,6 +382,15 @@ def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
lower = bounds[:, 1]
upper = bounds[:, 0]

# Make sure to convert the bounds to numpy array beforehand.
# Otherwise the error `ValueError: cannot convert float NaN to integer`
# is raised when calculating cell weights with `_map_longitude` when
# calling `np.sin(np.deg2rad(x))`.
if isinstance(lower, Array):
lower = lower.compute()
if isinstance(upper, Array):
upper = upper.compute()

tomvothecoder marked this conversation as resolved.
Show resolved Hide resolved
return lower.astype(np.float32), upper.astype(np.float32)


Expand Down Expand Up @@ -498,4 +528,4 @@ def _get_bounds_ensure_dtype(ds, axis):
if bounds.dtype != np.float32:
bounds = bounds.astype(np.float32)

return bounds.data
return bounds.values
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make sure to use .values over .data to convert Dask Arrays to np.ndarray

24 changes: 14 additions & 10 deletions xcdat/regridder/xesmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def __init__(
extrap_dist_exponent: Optional[float] = None,
extrap_num_src_pnts: Optional[int] = None,
ignore_degenerate: bool = True,
unmapped_to_nan: bool = True,
**options: Any,
):
"""Extension of ``xESMF`` regridder.
Expand Down Expand Up @@ -74,6 +75,8 @@ def __init__(

This only applies to "conservative" and "conservative_normed"
regridding methods.
unmapped_to_nan : bool
Sets values of unmapped points to `np.nan` instead of 0 (ESMF default).
**options : Any
Additional arguments passed to the underlying ``xesmf.XESMFRegridder``
constructor.
Expand Down Expand Up @@ -126,11 +129,17 @@ def __init__(
)

self._method = method
self._periodic = periodic
self._extrap_method = extrap_method
self._extrap_dist_exponent = extrap_dist_exponent
self._extrap_num_src_pnts = extrap_num_src_pnts
self._ignore_degenerate = ignore_degenerate

# Re-pack xesmf arguments, broken out for validation/documentation
options.update(
periodic=periodic,
extrap_method=extrap_method,
extrap_dist_exponent=extrap_dist_exponent,
extrap_num_src_pnts=extrap_num_src_pnts,
ignore_degenerate=ignore_degenerate,
unmapped_to_nan=unmapped_to_nan,
)

self._extra_options = options

def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
Expand All @@ -150,11 +159,6 @@ def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset:
self._input_grid,
self._output_grid,
method=self._method,
periodic=self._periodic,
extrap_method=self._extrap_method,
extrap_dist_exponent=self._extrap_dist_exponent,
extrap_num_src_pnts=self._extrap_num_src_pnts,
ignore_degenerate=self._ignore_degenerate,
**self._extra_options,
)

Expand Down
Loading