diff --git a/tests/test_regrid.py b/tests/test_regrid.py index 839e78cb..13de350b 100644 --- a/tests/test_regrid.py +++ b/tests/test_regrid.py @@ -384,29 +384,28 @@ def test_vertical_placeholder(self): with pytest.raises(NotImplementedError, match=""): regridder.vertical("so", ds) - def test_missing_dimension(self): - ds = fixtures.generate_dataset( - decode_times=True, cf_compliant=False, has_bounds=True - ) - - del ds.lat.attrs["axis"] + @pytest.mark.filterwarnings("ignore:.*invalid value.*divide.*:RuntimeWarning") + def test_output_bounds(self): + ds = self.coarse_3d_ds output_grid = grid.create_gaussian_grid(32) regridder = regrid2.Regrid2Regridder(ds, output_grid) - with pytest.raises( - RuntimeError, - match="Could not find axis 'lat', ensure 'lat' exists and the attributes are correct.", - ): - regridder.horizontal("ts", ds) + output_ds = regridder.horizontal("ts", ds) + + assert "lat_bnds" in output_ds + assert "lon_bnds" in output_ds + assert "time_bnds" in output_ds @pytest.mark.filterwarnings("ignore:.*invalid value.*divide.*:RuntimeWarning") - def test_output_bounds(self): + def test_output_bounds_missing_temporal(self): ds = fixtures.generate_dataset( decode_times=True, cf_compliant=False, has_bounds=True ) + ds = self.coarse_3d_ds.drop("time_bnds") + output_grid = grid.create_gaussian_grid(32) regridder = regrid2.Regrid2Regridder(ds, output_grid) @@ -415,7 +414,7 @@ def test_output_bounds(self): assert "lat_bnds" in output_ds assert "lon_bnds" in output_ds - assert "time_bnds" in output_ds + assert "time_bnds" not in output_ds @pytest.mark.parametrize( "src,dst,expected_west,expected_east,expected_shift", @@ -499,45 +498,16 @@ def test_regrid_input_mask(self): expected_output = np.array( [ - [0.0, 0.0, 0.0, 0.0], - [0.70710677, 0.70710677, 0.70710677, 0.70710677], - [0.70710677, 0.70710677, 0.70710677, 0.70710677], - [0.0, 0.0, 0.0, 0.0], + [0.0] * 4, + [0.70710677] * 4, + [0.70710677] * 4, + [0.0] * 4, ], dtype=np.float32, ) assert np.all(output_data.ts.values == expected_output) - def test_regrid_output_mask(self): - output_mask = [ - [0, 0, 0, 0], - [1, 1, 1, 1], - [1, 1, 1, 1], - [0, 0, 0, 0], - ] - - self.fine_2d_ds["mask"] = (("lat", "lon"), output_mask) - - regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds) - - output_data = regridder.horizontal("ts", self.coarse_2d_ds) - - expected_output = np.array( - [ - [1.0, 1.0, 1.0, 1.0], - [1e20, 1e20, 1e20, 1e20], - [1e20, 1e20, 1e20, 1e20], - [1.0, 1.0, 1.0, 1.0], - ], - dtype=np.float32, - ) - - # need to replace nans since nan != nan - output_data["ts"] = output_data.ts.fillna(1e20) - - assert np.all(output_data.ts.values == expected_output) - def test_preserve_attrs(self): regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds) @@ -547,7 +517,7 @@ def test_preserve_attrs(self): assert output_data["ts"].attrs == self.da_attrs for x in output_data.coords: - assert output_data[x].attrs == self.coarse_2d_ds[x].attrs + assert output_data[x].attrs == self.coarse_2d_ds[x].attrs, f"{x}" def test_regrid_2d(self): regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds) @@ -582,7 +552,7 @@ def test_regrid_4d(self): def test_map_longitude_coarse_to_fine(self): mapping, weights = regrid2._map_longitude( - self.coarse_lon_bnds, self.fine_lon_bnds + self.coarse_lon_bnds.values, self.fine_lon_bnds.values ) expected_mapping = [ @@ -604,7 +574,7 @@ def test_map_longitude_coarse_to_fine(self): def test_map_longitude_fine_to_coarse(self): mapping, weights = regrid2._map_longitude( - self.fine_lon_bnds, self.coarse_lon_bnds + self.fine_lon_bnds.values, self.coarse_lon_bnds.values ) expected_mapping = [ @@ -619,7 +589,7 @@ def test_map_longitude_fine_to_coarse(self): def test_map_latitude_coarse_to_fine(self): mapping, weights = regrid2._map_latitude( - self.coarse_lat_bnds, self.fine_lat_bnds + self.coarse_lat_bnds.values, self.fine_lat_bnds.values ) expected_mapping = [ @@ -648,7 +618,7 @@ def test_map_latitude_coarse_to_fine(self): def test_map_latitude_fine_to_coarse(self): mapping, weights = regrid2._map_latitude( - self.fine_lat_bnds, self.coarse_lat_bnds + self.fine_lat_bnds.values, self.coarse_lat_bnds.values ) expected_mapping = [ @@ -684,6 +654,12 @@ def test_reversed_extract_bounds(self): assert north.shape == (3,) assert north[0], north[-1] == (60, 90) + def test_get_bounds_ensure_dtype(self): + del self.coarse_2d_ds.lon.attrs["bounds"] + + with pytest.raises(RuntimeError): + regrid2._get_bounds_ensure_dtype(self.coarse_2d_ds, "X") + class TestXESMFRegridder: @pytest.fixture(autouse=True) diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index bf32956d..602a16d7 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -1,8 +1,9 @@ -from typing import Any, Dict, List, Tuple +from typing import Any, List, Tuple import numpy as np import xarray as xr +from xcdat.axis import get_dim_keys from xcdat.regridder.base import BaseRegridder, _preserve_bounds @@ -46,260 +47,186 @@ def __init__(self, input_grid: xr.Dataset, output_grid: xr.Dataset, **options: A """ super().__init__(input_grid, output_grid, **options) - self._src_lat = self._input_grid.bounds.get_bounds("Y") - self._src_lon = self._input_grid.bounds.get_bounds("X") - - self._dst_lat = self._output_grid.bounds.get_bounds("Y") - self._dst_lon = self._output_grid.bounds.get_bounds("X") - - self._lat_mapping: Any = None - self._lon_mapping: Any = None - - self._lat_weights: Any = None - self._lon_weights: Any = None - def vertical(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: """Placeholder for base class.""" raise NotImplementedError() def horizontal(self, data_var: str, ds: xr.Dataset) -> xr.Dataset: """See documentation in :py:func:`xcdat.regridder.regrid2.Regrid2Regridder`""" - input_data_var = ds.get(data_var, None) - - if input_data_var is None: + try: + input_data_var = ds[data_var] + except KeyError: raise KeyError( - f"The data variable '{data_var}' does not exist in the dataset." - ) + f"The data variable {data_var!r} does not exist in the dataset." + ) from None - # Do initial mapping between src/dst latitude and longitude. - if self._lat_mapping is None and self._lat_weights is None: - self._lat_mapping, self._lat_weights = _map_latitude( - self._src_lat, self._dst_lat - ) + src_lat_bnds = _get_bounds_ensure_dtype(self._input_grid, "Y") + src_lon_bnds = _get_bounds_ensure_dtype(self._input_grid, "X") - if self._lon_mapping is None and self._lon_weights is None: - self._lon_mapping, self._lon_weights = _map_longitude( - self._src_lon, self._dst_lon - ) + dst_lat_bnds = _get_bounds_ensure_dtype(self._output_grid, "Y") + dst_lon_bnds = _get_bounds_ensure_dtype(self._output_grid, "X") src_mask = self._input_grid.get("mask", None) # apply source mask to input data if src_mask is not None: - input_data_var = input_data_var.where(src_mask != 0.0) - - # operate on pure numpy - input_data = input_data_var.values - - axis_variable_name_map = {x: y[0] for x, y in input_data_var.cf.axes.items()} - - output_axis_sizes = self._output_axis_sizes(input_data_var) + masked_value = self._input_grid.attrs.get("_FillValue", None) - ordered_axis_names = list(output_axis_sizes) + if masked_value is None: + masked_value = self._input_grid.attrs.get("missing_value", 0.0) - output_data = self._regrid(input_data, output_axis_sizes, ordered_axis_names) + # Xarray defaults to masking with np.nan, CDAT masked with _FillValue or missing_value which defaults to 1e20 + input_data_var = input_data_var.where(src_mask != 0.0, masked_value) - output_ds = self._create_output_dataset( - ds, data_var, output_data, axis_variable_name_map, ordered_axis_names + output_data = _regrid( + input_data_var, src_lat_bnds, src_lon_bnds, dst_lat_bnds, dst_lon_bnds ) - dst_mask = self._output_grid.get("mask", None) - - if dst_mask is not None: - output_ds[data_var] = output_ds[data_var].where(dst_mask == 0.0) - - output_ds = _preserve_bounds(ds, self._output_grid, output_ds, ["X", "Y"]) + output_ds = _build_dataset( + ds, + data_var, + output_data, + dst_lat_bnds, + dst_lon_bnds, + self._input_grid, + self._output_grid, + ) return output_ds - def _output_axis_sizes(self, da: xr.DataArray) -> Dict[str, int]: - """Maps axes to output array sizes. - Parameters - ---------- - da : xr.DataArray - Data array containing variable to be regridded. +def _regrid( + input_data_var: xr.DataArray, + src_lat_bnds: np.ndarray, + src_lon_bnds: np.ndarray, + dst_lat_bnds: np.ndarray, + dst_lon_bnds: np.ndarray, +) -> np.ndarray: + lat_mapping, lat_weights = _map_latitude(src_lat_bnds, dst_lat_bnds) + lon_mapping, lon_weights = _map_longitude(src_lon_bnds, dst_lon_bnds) - Returns - ------- - Dict - Mapping of axis name e.g. ("X", "Y", etc) to output sizes. - """ - output_sizes = {} + # convert to pure numpy + input_data = input_data_var.astype(np.float32).data - axis_name_map = {y[0]: x for x, y in da.cf.axes.items()} + y_name, y_index = _get_dimension(input_data_var, "Y") + x_name, x_index = _get_dimension(input_data_var, "X") - for standard_name in da.sizes.keys(): - try: - axis_name = axis_name_map[standard_name] - except KeyError: - raise RuntimeError( - f"Could not find axis {standard_name!r}, ensure {standard_name!r} " - "exists and the attributes are correct." - ) + y_length = len(lat_mapping) + x_length = len(lon_mapping) - if standard_name in self._output_grid: - output_sizes[axis_name] = self._output_grid.sizes[standard_name] - else: - output_sizes[axis_name] = da.sizes[standard_name] + other_dims = { + x: y for x, y in input_data_var.sizes.items() if x not in (y_name, x_name) + } + other_sizes = list(other_dims.values()) - return output_sizes + data_shape = [y_length * x_length] + other_sizes + # output data is always float32 in original code + output_data = np.zeros(data_shape, dtype=np.float32) - def _regrid( - self, - input_data: np.ndarray, - axis_sizes: Dict[str, int], - ordered_axis_names: List[str], - ) -> np.ndarray: - """Applies regridding to input data. + is_2d = input_data_var.ndim <= 2 - Parameters - ---------- - input_data : np.ndarray - Input multi-dimensional array on source grid. - axis_sizes : Dict[str, int] - Mapping of axis name e.g. ("X", "Y", etc) to output sizes. - ordered_axis_names : List[str] - List of axis name in order of dimensions of ``input_data``. - - Returns - ------- - np.ndarray - Multi-dimensional array on destination grid. - """ - input_lat_index = ordered_axis_names.index("Y") - - input_lon_index = ordered_axis_names.index("X") - - output_shape = [axis_sizes[x] for x in ordered_axis_names] - - output_data = np.zeros(output_shape, dtype=np.float32) + # TODO: need to optimize further, investigate using ufuncs and dask arrays + # TODO: how common is lon by lat data? may need to reshape + for y in range(y_length): + y_seg = np.take(input_data, lat_mapping[y], axis=y_index) - base_put_index = self._base_put_indexes(axis_sizes) + for x in range(x_length): + x_seg = np.take(y_seg, lon_mapping[x], axis=x_index, mode="wrap") - for lat_index, lat_map in enumerate(self._lat_mapping): - lat_weight = self._lat_weights[lat_index] + cell_weight = np.dot(lat_weights[y], lon_weights[x]) - input_lat_segment = np.take(input_data, lat_map, axis=input_lat_index) + output_seg_index = y * x_length + x - for lon_index, lon_map in enumerate(self._lon_mapping): - lon_weight = self._lon_weights[lon_index] - - dot_weight = np.dot(lat_weight, lon_weight) - - cell_weight = np.sum(dot_weight) - - input_lon_segment = np.take( - input_lat_segment, lon_map, axis=input_lon_index + # using the `out` argument is more performant, places data directly into + # array memory rather than allocating a new variable. wasn't working for + # single element output, needs further investigation as we may not need + # branch + if is_2d: + output_data[output_seg_index] = np.divide( + np.sum( + np.multiply(x_seg, cell_weight), + axis=(y_index, x_index), + ), + np.sum(cell_weight), ) - - data = ( - np.nansum( - np.multiply(input_lon_segment, dot_weight), - axis=(input_lat_index, input_lon_index), - ) - / cell_weight + else: + output_seg = output_data[output_seg_index] + + np.divide( + np.sum( + np.multiply(x_seg, cell_weight), + axis=(y_index, x_index), + ), + np.sum(cell_weight), + out=output_seg, ) - # This only handles lat by lon and not lon by lat - put_index = base_put_index + ((lat_index * axis_sizes["X"]) + lon_index) + output_data_shape = [y_length, x_length] + other_sizes - np.put(output_data, put_index, data) + output_data = output_data.reshape(output_data_shape) - return output_data + output_order = [x + 2 for x in range(input_data_var.ndim - 2)] + [0, 1] - def _base_put_indexes(self, axis_sizes: Dict[str, int]) -> np.ndarray: - """Calculates the base indexes to place cell (0, 0). + output_data = output_data.transpose(output_order) - Example: - For a 3D array (time, lat, lon) with the shape (2, 2, 2) the offsets to - place cell (0, 0) in each time step would be [0, 4]. + return output_data.astype(np.float32) - For a 4D array (time, plev, lat, lon) with shape (2, 2, 2, 2) the offsets - to place cell (0, 0) in each time step would be [0, 4, 8, 16]. - Parameters - ---------- - axis_sizes : Dict[str, int] - Mapping of axis name e.g. ("X", "Y", etc) to output sizes. +def _build_dataset( + ds: xr.Dataset, + data_var: str, + output_data: np.ndarray, + dst_lat_bnds, + dst_lon_bnds, + input_grid: xr.Dataset, + output_grid: xr.Dataset, +) -> xr.Dataset: + input_data_var = ds[data_var] - Returns - ------- - np.ndarray - Array containing the base indexes to be used in np.put operations. - """ - extra_dims = set(axis_sizes) - set(["X", "Y"]) + output_coords: dict[str, xr.DataArray] = {} + output_data_vars: dict[str, xr.DataArray] = {} - number_of_offsets = np.multiply.reduce([axis_sizes[x] for x in extra_dims]) - - offset = np.multiply.reduce( - [axis_sizes[x] for x in extra_dims ^ set(axis_sizes)] - ) + dims = list(input_data_var.dims) - return (np.arange(number_of_offsets) * offset).astype(np.int64) + output_da = xr.DataArray( + output_data, + dims=dims, + coords=output_coords, + attrs=ds[data_var].attrs.copy(), + name=data_var, + ) - def _create_output_dataset( - self, - input_ds: xr.Dataset, - data_var: str, - output_data: np.ndarray, - axis_variable_name_map: Dict[str, str], - ordered_axis_names: List[str], - ) -> xr.Dataset: - """ - Creates the output Dataset containing the new variable on the destination grid. + output_data_vars[data_var] = output_da - Parameters - ---------- - input_ds : xr.Dataset - Input dataset containing coordinates and bounds for unmodified axes. - data_var : str - The name of the regridded variable. - output_data : np.ndarray - Output data array. - axis_variable_name_map : Dict[str, str] - Map of axis name e.g. ("X", "Y", etc) to variable name e.g. ("lon", "lat", etc). - ordered_axis_names : List[str] - List of axis names in the order observed for ``output_data``. - - Returns - ------- - xr.Dataset - Dataset containing the variable on the destination grid. - """ - variable_axis_name_map = {y: x for x, y in axis_variable_name_map.items()} + output_ds = xr.Dataset( + output_data_vars, + attrs=input_grid.attrs.copy(), + ) - coords = {} + output_ds = _preserve_bounds(ds, output_grid, output_ds, ["X", "Y"]) - # Grab coords and bounds from appropriate dataset. - for variable_name, axis_name in variable_axis_name_map.items(): - if axis_name in ["X", "Y"]: - coords[variable_name] = self._output_grid[variable_name].copy() - else: - coords[variable_name] = input_ds[variable_name].copy() + return output_ds - output_da = xr.DataArray( - output_data, - dims=[axis_variable_name_map[x] for x in ordered_axis_names], - coords=coords, - attrs=input_ds[data_var].attrs.copy(), - ) - data_vars = {data_var: output_da} +def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: + """ + Map source to destination latitude. - return xr.Dataset(data_vars, attrs=input_ds.attrs.copy()) + Source cells are grouped by the contribution to each output cell. + Source cells have new boundaries calculated by finding minimum northern + and maximum southern boundary between each source cell and the destination + cell it contributes to. -def _map_latitude(src: xr.DataArray, dst: xr.DataArray) -> Tuple[List, List]: - """ - Map source to destination latitude. + The source cell weights are calculated by taking the difference of sin's + between these new boundary pairs. Parameters ---------- - src : xr.DataArray - DataArray containing the source latitude bounds. - dst : xr.DataArray - DataArray containing the destination latitude bounds. + src : np.ndarray + Array containing the source latitude bounds. + dst : np.ndarray + Array containing the destination latitude bounds. Returns ------- @@ -309,36 +236,55 @@ def _map_latitude(src: xr.DataArray, dst: xr.DataArray) -> Tuple[List, List]: src_south, src_north = _extract_bounds(src) dst_south, dst_north = _extract_bounds(dst) - mapping = [] - weights = [] - - for i in range(dst.shape[0]): - contrib = np.where( - np.logical_and(src_south < dst_north[i], src_north > dst_south[i]) - )[0] + dst_length = dst_south.shape[0] + + # finds contributing source cells for each destination cell based on bounds values + # output is a list of lists containing the contributing cell indexes + # e.g. let src_south be [90, 45, 0, -45], source_north be [45, 0, -45, -90], + # dst_north[x] be 70, and dst_south[x] be -70 then the result would be [[1, 2]] + mapping = [ + np.where(np.logical_and(src_south < dst_north[x], src_north > dst_south[x]))[0] + for x in range(dst_length) + ] + + # finds minimum and maximum bounds for each output cell, considers source and + # destination bounds for each cell + bounds = [ + (np.minimum(dst_north[x], src_north[y]), np.maximum(dst_south[x], src_south[y])) + for x, y in enumerate(mapping) + ] + + # convert latitude to cell weight (difference of height above/below equator) + weights = [ + (np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))).reshape((-1, 1)) + for x, y in bounds + ] - mapping.append(contrib) + return mapping, weights - north_bounds = np.minimum(dst_north[i], src_north[contrib]) - south_bounds = np.maximum(dst_south[i], src_south[contrib]) - weight = np.sin(np.deg2rad(north_bounds)) - np.sin(np.deg2rad(south_bounds)) +def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: + """ + Map source to destination longitude. - weights.append(weight.values.reshape(contrib.shape[0], 1)) + Source boundaries are aligned to the most western destination cell. - return mapping, weights + Source cells are grouped by the contribution to each output cell. + The source cell weights are calculated by find the difference of the + following min/max for each input cell. Minimum of eastern source bounds + and the eastern bounds of the destination cell it contributes to. Maximum + of western source bounds and the western bounds of the destination cell + it contributes to. -def _map_longitude(src: xr.DataArray, dst: xr.DataArray) -> Tuple[List, List]: - """ - Map source to destination longitude. + These weights are then shifted to align with the destination longitude. Parameters ---------- - src : xr.DataArray - DataArray containing source longitude bounds. - dst : xr.DataArray - DataArray containing destination longitude bounds. + src : np.ndarray + Array containing source longitude bounds. + dst : np.ndarray + Array containing destination longitude bounds. Returns ------- @@ -348,50 +294,64 @@ def _map_longitude(src: xr.DataArray, dst: xr.DataArray) -> Tuple[List, List]: src_west, src_east = _extract_bounds(src) dst_west, dst_east = _extract_bounds(dst) + # align source and destination longitude shifted_src_west, shifted_src_east, shift = _align_axis( - src_west, src_east, dst_west + src_west, + src_east, + dst_west, ) - mapping = [] - weights = [] src_length = src_west.shape[0] + dst_length = dst_west.shape[0] - for i in range(dst_west.shape[0]): - contrib = np.where( + # finds contributing source cells for each destination cell based on bounds values + # output is a list of lists containing the contributing cell indexes + mapping = [ + np.where( np.logical_and( - shifted_src_west < dst_east[i], shifted_src_east > dst_west[i] + shifted_src_west < dst_east[x], shifted_src_east > dst_west[x] ) )[0] - - weight = np.minimum(dst_east[i], shifted_src_east[contrib]) - np.maximum( - dst_west[i], shifted_src_west[contrib] - ) - - weights.append(weight.values.reshape(1, contrib.shape[0])) - - contrib += shift - - wrapped = np.where(contrib > src_length - 1) - - contrib[wrapped] -= src_length - - mapping.append(contrib) + for x in range(dst_length) + ] + + # weights are just the difference between minimum and maximum of contributing bounds + # for each destination cell + weights = [ + ( + np.minimum(dst_east[x], shifted_src_east[y]) + - np.maximum(dst_west[x], shifted_src_west[y]) + ).reshape((1, -1)) + for x, y in enumerate(mapping) + ] + + # need to adjust the source contributing indexes by the shift required to align + # source and destination longitude + for x in range(len(mapping)): + # shift the mapping indexes by the shift used to determine the weights + mapping[x] += shift + + # find the contributing indexes that need to be wrapped + wrapped = np.where(mapping[x] > src_length - 1)[0] + + # wrap the contributing index as all indexes must be Tuple[xr.DataArray, xr.DataArray]: +def _extract_bounds(bounds: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """ - Extract lower and upper bounds from an axis. + Extract lower and upper bounds from an axis. - Parameters - ---------- - bounds : xr.DataArray - Dataset containing axis with bounds. + Parameters + ---------- + bounds : np.ndarray + Dataset containing axis with bounds. - Returns - ------- - Tuple[xr.DataArray, xr.DataArray] + Returns + ------- + Tuple[np.ndarray, np.ndarray] A tuple containing the lower and upper bounds for the axis. """ if bounds[0, 0] < bounds[0, 1]: @@ -401,43 +361,50 @@ def _extract_bounds(bounds: xr.DataArray) -> Tuple[xr.DataArray, xr.DataArray]: lower = bounds[:, 1] upper = bounds[:, 0] - return lower, upper + return lower.astype(np.float32), upper.astype(np.float32) def _align_axis( - src_west: xr.DataArray, src_east: xr.DataArray, dst_west: xr.DataArray -) -> Tuple[xr.DataArray, xr.DataArray, int]: + src_west: np.ndarray, + src_east: np.ndarray, + dst_west: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray, int]: """ - Aligns a longitudinal source axis to the destination axis. + Aligns a source and destination longitude axis. Parameters ---------- - src_west : xr.DataArray - DataArray containing the western source bounds. - src_east : xr.DataArray - DataArray containing the eastern source bounds. - dst_west : xr.DataArray - DataArray containing the western destination bounds. + src_west : np.ndarray + Array containing the western source bounds. + src_east : np.ndarray + Array containing the eastern source bounds. + dst_west : np.ndarray + Array containing the western destination bounds. Returns ------- - Tuple[xr.DataArray, xr.DataArray, int] + Tuple[np.ndarray, np.ndarray, int] A tuple containing the shifted western source bounds, the shifted eastern source bounds, and the number of places shifted to align axis. """ + # find smallest western bounds west_most = np.minimum(dst_west[0], dst_west[-1]) + # find cell index required to align bounds alignment_index = _vpertub((west_most - src_west[-1]) / 360.0) - if src_west[0] < src_west[-1]: - alignment_index += 1 - else: - alignment_index -= 1 + # shift index depending on first/last source bounds + alignment_index = ( + alignment_index + 1 if src_west[0] < src_west[-1] else alignment_index - 1 + ) - src_alignment_index = np.where( - _vpertub((west_most - src_west) / 360.0) != alignment_index - )[0][0] + # find relative indexes for each source cell to the destinations most western cell + relative_postition = _vpertub((west_most - src_west) / 360.0) + # find all index values that are not the alignment index + src_alignment_index = np.where(relative_postition != alignment_index)[0][0] + + # determine the shift factor required to align source and destination bounds if src_west[0] < src_west[-1]: if west_most == src_west[src_alignment_index]: shift = src_alignment_index @@ -451,20 +418,26 @@ def _align_axis( src_length = src_west.shape[0] + # shift the source index values shifted_indexes = np.arange(src_length + 1) + shift + # find index values that need to be shift to be within 0 - src_length wrapped = np.where(shifted_indexes > src_length - 1) + # shift the indexes shifted_indexes[wrapped] -= src_length - shifted_src_west = src_west[shifted_indexes] + 360.0 * _vpertub( - (west_most - src_west[shifted_indexes]) / 360.0 + # reorder src_west and add portion to align + shifted_src_west = ( + src_west[shifted_indexes] + 360.0 * relative_postition[shifted_indexes] ) - shifted_src_east = src_east[shifted_indexes] + 360.0 * _vpertub( - (west_most - src_west[shifted_indexes]) / 360.0 + # reorder src_east and add portion to align + shifted_src_east = ( + src_east[shifted_indexes] + 360.0 * relative_postition[shifted_indexes] ) + # handle ends of each interval if src_west[-1] > src_west[0]: if shifted_src_west[0] > west_most: shifted_src_west[0] += -360.0 @@ -477,7 +450,7 @@ def _align_axis( return shifted_src_west, shifted_src_east, shift -def _pertub(value: xr.DataArray) -> xr.DataArray: +def _pertub(value: np.ndarray) -> np.ndarray: """ Pertub a value. @@ -486,12 +459,12 @@ def _pertub(value: xr.DataArray) -> xr.DataArray: Parameters ---------- - value : xr.DataArray + value : np.ndarray Value to pertub. Returns ------- - xr.DataArray + np.ndarray Value that's been pertubed. """ if value >= 0.0: @@ -499,8 +472,30 @@ def _pertub(value: xr.DataArray) -> xr.DataArray: else: offset = np.floor(value - 0.000001) + 1.0 - return xr.DataArray(offset) + return offset # vectorize version of pertub _vpertub = np.vectorize(_pertub) + + +def _get_dimension(input_data_var, cf_axis_name): + name = get_dim_keys(input_data_var, cf_axis_name) + + index = input_data_var.dims.index(name) + + return name, index + + +def _get_bounds_ensure_dtype(ds, axis): + try: + name = ds.cf.bounds[axis][0] + except (KeyError, IndexError): + raise RuntimeError(f"Could not determine {axis!r} bounds") + else: + bounds = ds[name] + + if bounds.dtype != np.float32: + bounds = bounds.astype(np.float32) + + return bounds.data