Skip to content

Commit

Permalink
Add tests for rse and windowed_max_run_sum
Browse files Browse the repository at this point in the history
  • Loading branch information
seblehner committed Sep 16, 2024
1 parent 302fe27 commit b49d644
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tests/test_run_length.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,6 +124,35 @@ def test_rle(ufunc, use_dask, index):
np.testing.assert_array_equal(out, expected)


@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize("index", ["first", "last"])
def test_rse(ufunc, use_dask, index):
if use_dask and ufunc:
pytest.xfail("rse_1d is not implemented for dask arrays.")

values = np.zeros((10, 365, 4, 4))
time = pd.date_range("2000-01-01", periods=365, freq="D")
values[:, 1:11, ...] = 30
da = xr.DataArray(values, coords={"time": time}, dims=("a", "time", "b", "c"))

if ufunc:
pytest.xfail("rse_1d is not implemented.")
else:
if use_dask:
da = da.chunk({"a": 1, "b": 2})

out = rl.rse(da, index=index).mean(["a", "b", "c"])
if index == "last":
expected = np.zeros(365)
expected[1:10] = np.nan
expected[10] = 300
else:
expected = np.zeros(365)
expected[1] = 300
expected[2:11] = np.nan
np.testing.assert_array_equal(out, expected)


@pytest.mark.parametrize("use_dask", [True, False])
@pytest.mark.parametrize("index", ["first", "last"])
def test_extract_events_identity(use_dask, index):
Expand Down Expand Up @@ -370,6 +399,16 @@ def test_simple(self, index):
) + len(a[34:45])


class TestWindowedMaxRunSum:
@pytest.mark.parametrize("index", ["first", "last"])
def test_simple(self, index):
a = xr.DataArray(np.zeros(50, float), dims=("time",))
a[4:6] = 5 # too short
a[25:30] = 5 # long enough, but not max
a[35:45] = 5 # max sum => yields 10*5
assert rl.windowed_max_run_sum(a, 3, dim="time", index=index) == 50


class TestLastRun:
@pytest.mark.parametrize(
"coord,expected",
Expand Down

0 comments on commit b49d644

Please sign in to comment.