Skip to content

Commit

Permalink
[PR]: Improving regrid2 performance (#533)
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonb5 authored Mar 8, 2024
1 parent 78dedf0 commit 38cdfd5
Show file tree
Hide file tree
Showing 2 changed files with 294 additions and 323 deletions.
78 changes: 27 additions & 51 deletions tests/test_regrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,29 +384,28 @@ def test_vertical_placeholder(self):
with pytest.raises(NotImplementedError, match=""):
regridder.vertical("so", ds)

def test_missing_dimension(self):
ds = fixtures.generate_dataset(
decode_times=True, cf_compliant=False, has_bounds=True
)

del ds.lat.attrs["axis"]
@pytest.mark.filterwarnings("ignore:.*invalid value.*divide.*:RuntimeWarning")
def test_output_bounds(self):
ds = self.coarse_3d_ds

output_grid = grid.create_gaussian_grid(32)

regridder = regrid2.Regrid2Regridder(ds, output_grid)

with pytest.raises(
RuntimeError,
match="Could not find axis 'lat', ensure 'lat' exists and the attributes are correct.",
):
regridder.horizontal("ts", ds)
output_ds = regridder.horizontal("ts", ds)

assert "lat_bnds" in output_ds
assert "lon_bnds" in output_ds
assert "time_bnds" in output_ds

@pytest.mark.filterwarnings("ignore:.*invalid value.*divide.*:RuntimeWarning")
def test_output_bounds(self):
def test_output_bounds_missing_temporal(self):
ds = fixtures.generate_dataset(
decode_times=True, cf_compliant=False, has_bounds=True
)

ds = self.coarse_3d_ds.drop("time_bnds")

output_grid = grid.create_gaussian_grid(32)

regridder = regrid2.Regrid2Regridder(ds, output_grid)
Expand All @@ -415,7 +414,7 @@ def test_output_bounds(self):

assert "lat_bnds" in output_ds
assert "lon_bnds" in output_ds
assert "time_bnds" in output_ds
assert "time_bnds" not in output_ds

@pytest.mark.parametrize(
"src,dst,expected_west,expected_east,expected_shift",
Expand Down Expand Up @@ -499,45 +498,16 @@ def test_regrid_input_mask(self):

expected_output = np.array(
[
[0.0, 0.0, 0.0, 0.0],
[0.70710677, 0.70710677, 0.70710677, 0.70710677],
[0.70710677, 0.70710677, 0.70710677, 0.70710677],
[0.0, 0.0, 0.0, 0.0],
[0.0] * 4,
[0.70710677] * 4,
[0.70710677] * 4,
[0.0] * 4,
],
dtype=np.float32,
)

assert np.all(output_data.ts.values == expected_output)

def test_regrid_output_mask(self):
output_mask = [
[0, 0, 0, 0],
[1, 1, 1, 1],
[1, 1, 1, 1],
[0, 0, 0, 0],
]

self.fine_2d_ds["mask"] = (("lat", "lon"), output_mask)

regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds)

output_data = regridder.horizontal("ts", self.coarse_2d_ds)

expected_output = np.array(
[
[1.0, 1.0, 1.0, 1.0],
[1e20, 1e20, 1e20, 1e20],
[1e20, 1e20, 1e20, 1e20],
[1.0, 1.0, 1.0, 1.0],
],
dtype=np.float32,
)

# need to replace nans since nan != nan
output_data["ts"] = output_data.ts.fillna(1e20)

assert np.all(output_data.ts.values == expected_output)

def test_preserve_attrs(self):
regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds)

Expand All @@ -547,7 +517,7 @@ def test_preserve_attrs(self):
assert output_data["ts"].attrs == self.da_attrs

for x in output_data.coords:
assert output_data[x].attrs == self.coarse_2d_ds[x].attrs
assert output_data[x].attrs == self.coarse_2d_ds[x].attrs, f"{x}"

def test_regrid_2d(self):
regridder = regrid2.Regrid2Regridder(self.coarse_2d_ds, self.fine_2d_ds)
Expand Down Expand Up @@ -582,7 +552,7 @@ def test_regrid_4d(self):

def test_map_longitude_coarse_to_fine(self):
mapping, weights = regrid2._map_longitude(
self.coarse_lon_bnds, self.fine_lon_bnds
self.coarse_lon_bnds.values, self.fine_lon_bnds.values
)

expected_mapping = [
Expand All @@ -604,7 +574,7 @@ def test_map_longitude_coarse_to_fine(self):

def test_map_longitude_fine_to_coarse(self):
mapping, weights = regrid2._map_longitude(
self.fine_lon_bnds, self.coarse_lon_bnds
self.fine_lon_bnds.values, self.coarse_lon_bnds.values
)

expected_mapping = [
Expand All @@ -619,7 +589,7 @@ def test_map_longitude_fine_to_coarse(self):

def test_map_latitude_coarse_to_fine(self):
mapping, weights = regrid2._map_latitude(
self.coarse_lat_bnds, self.fine_lat_bnds
self.coarse_lat_bnds.values, self.fine_lat_bnds.values
)

expected_mapping = [
Expand Down Expand Up @@ -648,7 +618,7 @@ def test_map_latitude_coarse_to_fine(self):

def test_map_latitude_fine_to_coarse(self):
mapping, weights = regrid2._map_latitude(
self.fine_lat_bnds, self.coarse_lat_bnds
self.fine_lat_bnds.values, self.coarse_lat_bnds.values
)

expected_mapping = [
Expand Down Expand Up @@ -684,6 +654,12 @@ def test_reversed_extract_bounds(self):
assert north.shape == (3,)
assert north[0], north[-1] == (60, 90)

def test_get_bounds_ensure_dtype(self):
del self.coarse_2d_ds.lon.attrs["bounds"]

with pytest.raises(RuntimeError):
regrid2._get_bounds_ensure_dtype(self.coarse_2d_ds, "X")


class TestXESMFRegridder:
@pytest.fixture(autouse=True)
Expand Down
Loading

0 comments on commit 38cdfd5

Please sign in to comment.