Skip to content

Commit

Permalink
Fixes docstrings and adds comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonb5 committed Feb 23, 2024
1 parent ad61067 commit 96eaae2
Showing 1 changed file with 43 additions and 8 deletions.
51 changes: 43 additions & 8 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,11 +119,13 @@ def _regrid(
other_sizes = list(other_dims.values())

data_shape = [y_length * x_length] + other_sizes
# output data is always float32 in original code
output_data = np.zeros(data_shape, dtype=np.float32)

is_2d = input_data_var.ndim <= 2

# need to optimize
# TODO: need to optimize further, investigate using ufuncs and dask arrays
# TODO: how common is lon by lat data? may need to reshape
for y in range(y_length):
y_seg = np.take(input_data, lat_mapping[y], axis=y_index)

Expand All @@ -134,6 +136,10 @@ def _regrid(

output_seg_index = y * x_length + x

# using the `out` argument is more performant, places data directly into
# array memory rather than allocating a new variable. wasn't working for
# single element output, needs further investigation as we may not need
# branch
if is_2d:
output_data[output_seg_index] = np.divide(
np.sum(
Expand Down Expand Up @@ -208,9 +214,9 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
Parameters
----------
src : np.ndarray
DataArray containing the source latitude bounds.
Array containing the source latitude bounds.
dst : np.ndarray
DataArray containing the destination latitude bounds.
Array containing the destination latitude bounds.
Returns
-------
Expand All @@ -222,16 +228,23 @@ def _map_latitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:

dst_length = dst_south.shape[0]

# finds contributing source cells for each destination cell based on bounds values
# output is a list of lists containing the contributing cell indexes
# e.g. let src_south be [90, 45, 0, -45], source_north be [45, 0, -45, -90],
# dst_north[x] be 70, and dst_south[x] be -70 then the result would be [[1, 2]]
mapping = [
np.where(np.logical_and(src_south < dst_north[x], src_north > dst_south[x]))[0]
for x in range(dst_length)
]

# finds minimum and maximum bounds for each output cell, considers source and
# destination bounds for each cell
bounds = [
(np.minimum(dst_north[x], src_north[y]), np.maximum(dst_south[x], src_south[y]))
for x, y in enumerate(mapping)
]

# convert latitude to cell weight (difference of length from equator)
weights = [
(np.sin(np.deg2rad(x)) - np.sin(np.deg2rad(y))).reshape((-1, 1))
for x, y in bounds
Expand All @@ -247,9 +260,9 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
Parameters
----------
src : np.ndarray
DataArray containing source longitude bounds.
Array containing source longitude bounds.
dst : np.ndarray
DataArray containing destination longitude bounds.
Array containing destination longitude bounds.
Returns
-------
Expand All @@ -259,6 +272,7 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
src_west, src_east = _extract_bounds(src)
dst_west, dst_east = _extract_bounds(dst)

# align source and destination longitude
shifted_src_west, shifted_src_east, shift = _align_axis(
src_west,
src_east,
Expand All @@ -268,6 +282,8 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
src_length = src_west.shape[0]
dst_length = dst_west.shape[0]

# finds contributing source cells for each destination cell based on bounds values
# output is a list of lists containing the contributing cell indexes
mapping = [
np.where(
np.logical_and(
Expand All @@ -277,6 +293,8 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
for x in range(dst_length)
]

# weights are just the difference between minimum and maximum of contributing bounds
# for each destination cell
weights = [
(
np.minimum(dst_east[x], shifted_src_east[y])
Expand All @@ -285,11 +303,16 @@ def _map_longitude(src: np.ndarray, dst: np.ndarray) -> Tuple[List, List]:
for x, y in enumerate(mapping)
]

# need to adjust the source contributing indexes by the shift required to align
# source and destination longitude
for x in range(len(mapping)):
# shift the mapping indexes by the shift used to determine the weights
mapping[x] += shift

# find the contributing indexes that need to be wrapped
wrapped = np.where(mapping[x] > src_length - 1)[0]

# wrap the contributing index as all indexes must be <src_length
mapping[x][wrapped] -= src_length

return mapping, weights
Expand Down Expand Up @@ -330,30 +353,36 @@ def _align_axis(
Parameters
----------
src_west : np.ndarray
DataArray containing the western source bounds.
Array containing the western source bounds.
src_east : np.ndarray
DataArray containing the eastern source bounds.
Array containing the eastern source bounds.
dst_west : np.ndarray
DataArray containing the western destination bounds.
Array containing the western destination bounds.
Returns
-------
Tuple[np.ndarray, np.ndarray, int]
A tuple containing the shifted western source bounds, the shifted eastern
source bounds, and the number of places shifted to align axis.
"""
# find smallest western bounds
west_most = np.minimum(dst_west[0], dst_west[-1])

# find cell index required to align bounds
alignment_index = _vpertub((west_most - src_west[-1]) / 360.0)

# shift index depending on first/last source bounds
alignment_index = (
alignment_index + 1 if src_west[0] < src_west[-1] else alignment_index - 1
)

# find relative indexes for each source cell to the destinations most western cell
relative_postition = _vpertub((west_most - src_west) / 360.0)

# find all index values that are not the alignment index
src_alignment_index = np.where(relative_postition != alignment_index)[0][0]

# determine the shift factor required to align source and destination bounds
if src_west[0] < src_west[-1]:
if west_most == src_west[src_alignment_index]:
shift = src_alignment_index
Expand All @@ -367,20 +396,26 @@ def _align_axis(

src_length = src_west.shape[0]

# shift the source index values
shifted_indexes = np.arange(src_length + 1) + shift

# find index values that need to be shift to be within 0 - src_length
wrapped = np.where(shifted_indexes > src_length - 1)

# shift the indexes
shifted_indexes[wrapped] -= src_length

# reorder src_west and add portion to align
shifted_src_west = (
src_west[shifted_indexes] + 360.0 * relative_postition[shifted_indexes]
)

# reorder src_east and add portion to align
shifted_src_east = (
src_east[shifted_indexes] + 360.0 * relative_postition[shifted_indexes]
)

# handle ends of each interval
if src_west[-1] > src_west[0]:
if shifted_src_west[0] > west_most:
shifted_src_west[0] += -360.0
Expand Down

0 comments on commit 96eaae2

Please sign in to comment.