From 96eaae21ebdf20fe51831bca24a8a892a6ed6cf0 Mon Sep 17 00:00:00 2001 From: Jason Boutte Date: Fri, 23 Feb 2024 12:58:25 -0800 Subject: [PATCH] Fixes docstrings and adds comments --- xcdat/regridder/regrid2.py | 51 ++++++++++++++++++++++++++++++++------ 1 file changed, 43 insertions(+), 8 deletions(-) diff --git a/xcdat/regridder/regrid2.py b/xcdat/regridder/regrid2.py index cb4135d9..d8df3372 100644 --- a/xcdat/regridder/regrid2.py +++ b/xcdat/regridder/regrid2.py @@ -119,11 +119,13 @@ def _regrid( other_sizes = list(other_dims.values()) data_shape = [y_length * x_length] + other_sizes + # output data is always float32 in original code output_data = np.zeros(data_shape, dtype=np.float32) is_2d = input_data_var.ndim <= 2 - # need to optimize + # TODO: need to optimize further, investigate using ufuncs and dask arrays + # TODO: how common is lon by lat data? may need to reshape for y in range(y_length): y_seg = np.take(input_data, lat_mapping[y], axis=y_index) @@ -134,6 +136,10 @@ def _regrid( output_seg_index = y * x_length + x + # using the `out` argument is more performant, places data directly into + # array memory rather than allocating a new variable. wasn't working for + # single element output, needs further investigation as we may not need + # branch if is_2d: output_data[output_seg_index] = np.divide( np.sum( @@ -208,9 +214,9 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: Parameters ---------- src : np.ndarray - DataArray containing the source latitude bounds. + Array containing the source latitude bounds. dst : np.ndarray - DataArray containing the destination latitude bounds. + Array containing the destination latitude bounds. Returns ------- @@ -222,16 +228,23 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: dst_length = dst_south.shape[0] + # finds contributing source cells for each destination cell based on bounds values + # output is a list of lists containing the contributing cell indexes + # e.g. let src_south be [90, 45, 0, -45], source_north be [45, 0, -45, -90], + # dst_north[x] be 70, and dst_south[x] be -70 then the result would be [[1, 2]] mapping = [ np.where(np.logical_and(src_south < dst_north[x], src_north > dst_south[x]))[0] for x in range(dst_length) ] + # finds minimum and maximum bounds for each output cell, considers source and + # destination bounds for each cell bounds = [ (np.minimum(dst_north[x], src_north[y]), np.maximum(dst_south[x], src_south[y])) for x, y in enumerate(mapping) ] + # convert latitude to cell weight (difference of length from equator) weights = [ (np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))).reshape((-1, 1)) for x, y in bounds @@ -247,9 +260,9 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: Parameters ---------- src : np.ndarray - DataArray containing source longitude bounds. + Array containing source longitude bounds. dst : np.ndarray - DataArray containing destination longitude bounds. + Array containing destination longitude bounds. Returns ------- @@ -259,6 +272,7 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: src_west, src_east = _extract_bounds(src) dst_west, dst_east = _extract_bounds(dst) + # align source and destination longitude shifted_src_west, shifted_src_east, shift = _align_axis( src_west, src_east, @@ -268,6 +282,8 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: src_length = src_west.shape[0] dst_length = dst_west.shape[0] + # finds contributing source cells for each destination cell based on bounds values + # output is a list of lists containing the contributing cell indexes mapping = [ np.where( np.logical_and( @@ -277,6 +293,8 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: for x in range(dst_length) ] + # weights are just the difference between minimum and maximum of contributing bounds + # for each destination cell weights = [ ( np.minimum(dst_east[x], shifted_src_east[y]) @@ -285,11 +303,16 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]: for x, y in enumerate(mapping) ] + # need to adjust the source contributing indexes by the shift required to align + # source and destination longitude for x in range(len(mapping)): + # shift the mapping indexes by the shift used to determine the weights mapping[x] += shift + # find the contributing indexes that need to be wrapped wrapped = np.where(mapping[x] > src_length - 1)[0] + # wrap the contributing index as all indexes must be src_length - 1) + # shift the indexes shifted_indexes[wrapped] -= src_length + # reorder src_west and add portion to align shifted_src_west = ( src_west[shifted_indexes] + 360.0 * relative_postition[shifted_indexes] ) + # reorder src_east and add portion to align shifted_src_east = ( src_east[shifted_indexes] + 360.0 * relative_postition[shifted_indexes] ) + # handle ends of each interval if src_west[-1] > src_west[0]: if shifted_src_west[0] > west_most: shifted_src_west[0] += -360.0