Skip to content

Commit

Permalink
Add non-xarray and stacked ensemble tests to test_eva.py
Browse files Browse the repository at this point in the history
  • Loading branch information
stellema committed Nov 21, 2023
1 parent 76d0325 commit e66111e
Showing 1 changed file with 68 additions and 19 deletions.
87 changes: 68 additions & 19 deletions unseen/tests/test_eva.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,34 +74,68 @@ def add_example_gev_trend(data):
return data + trend


def example_da_gev_forecast():
"""An example 2D multi-ensemble trended GEV data array and its distribution parameters."""
data, theta = example_da_gev_1d()
shape, loc, scale = theta

da_list = []
for i in range(5):
rvs = genextreme.rvs(
shape, loc=loc, scale=scale, size=(data.time.size), random_state=i
)
da = xr.DataArray(
[
rvs,
],
coords=dict(ensemble=[i], time=data.time),
dims=["ensemble", "time"],
)
da = add_example_gev_trend(da)
da_list.append(da)
data = xr.concat(da_list, "ensemble")

data_stacked = data.stack({"sample": ["time", "ensemble"]})
return data_stacked, theta


def test_fit_gev_1d():
"""Run stationary fit using 1D array & check results."""
data, theta_i = example_da_gev_1d()
theta = fit_gev(data, stationary=True, check_fit=False)
theta = fit_gev(data, stationary=True, time_dim="time", check_fit=False)
# Check fitted params match params used to create data.
npt.assert_allclose(theta, theta_i, rtol=rtol)


def test_fit_gev_1d_numpy():
"""Run stationary fit using 1D np.ndarray & check results."""
data, theta_i = example_da_gev_1d()
data = data.values
theta = fit_gev(data, stationary=True, time_dim=None, check_fit=False)
# Check fitted params match params used to create data.
npt.assert_allclose(theta, theta_i, rtol=rtol)


def test_fit_gev_1d_dask():
"""Run stationary fit using 1D dask array & check results."""
data, theta_i = example_da_gev_1d_dask()
theta = fit_gev(data, stationary=True, check_fit=False)
theta = fit_gev(data, stationary=True, time_dim="time", check_fit=False)
# Check fitted params match params used to create data.
npt.assert_allclose(theta, theta_i, rtol=rtol)


def test_fit_gev_3d():
"""Run stationary fit using 3D array & check results."""
data, theta_i = example_da_gev_3d()
theta = fit_gev(data, stationary=True, check_fit=False)
theta = fit_gev(data, stationary=True, time_dim="time", check_fit=False)
# Check fitted params match params used to create data.
npt.assert_allclose(theta, theta_i, rtol=rtol)


def test_fit_gev_3d_dask():
"""Run stationary fit using 3D dask array & check results."""
data, theta_i = example_da_gev_3d_dask()
theta = fit_gev(data, stationary=True, check_fit=False)
theta = fit_gev(data, stationary=True, time_dim="time", check_fit=False)
# Check fitted params match params used to create data.
npt.assert_allclose(theta, theta_i, rtol=rtol)

Expand All @@ -111,52 +145,67 @@ def test_fit_ns_gev_1d():
data, _ = example_da_gev_1d()
data = add_example_gev_trend(data)

# Check function runs.
theta = fit_gev(data, stationary=False, check_fit=False)
theta = fit_gev(data, stationary=False, time_dim="time", check_fit=False)
pvalue = check_gev_fit(data, theta, time_dim="time")
assert np.all(pvalue) > alpha


check_gev_fit(data, theta, time_dim="time")
def test_fit_ns_gev_1d_numpy():
"""Run non-stationary fit using 1D np.ndarray & check results."""
data, _ = example_da_gev_1d()
data = add_example_gev_trend(data)
data = data.values

theta = fit_gev(data, stationary=False, time_dim=None, check_fit=False)
pvalue = check_gev_fit(data, theta, time_dim="time")
assert np.all(pvalue) > alpha


def test_fit_ns_gev_1d_dask():
"""Run non-stationary fit using 1D dask array & check results."""
data, _ = example_da_gev_1d_dask()

# Add a positive linear trend.
data = add_example_gev_trend(data)

# Check function runs.
theta = fit_gev(data, stationary=False, check_fit=False)

theta = fit_gev(data, stationary=False, time_dim="time", check_fit=False)
pvalue = check_gev_fit(data, theta, time_dim="time")
assert np.all(pvalue) > alpha


def test_fit_ns_gev_3d():
"""Run non-stationary fit using 3D array & check results."""
data, _ = example_da_gev_3d()

# Add a positive linear trend.
data = add_example_gev_trend(data)

# Check function runs.
theta = fit_gev(data, stationary=False, check_fit=False)

theta = fit_gev(data, stationary=False, time_dim="time", check_fit=False)
pvalue = check_gev_fit(data, theta, time_dim="time")
assert np.all(pvalue) > alpha


def test_fit_ns_gev_3d_dask():
"""Run non-stationary fit using 3D dask array & check results."""
data, _ = example_da_gev_3d_dask()

# Add a positive linear trend.
data = add_example_gev_trend(data)

# Check function runs.
theta = fit_gev(data, stationary=False, check_fit=False)

theta = fit_gev(data, stationary=False, time_dim="time", check_fit=False)
pvalue = check_gev_fit(data, theta, time_dim="time")
assert np.all(pvalue) > alpha


def test_fit_ns_gev_forecast():
"""Run stationary fit using 1D array & check results."""
data, theta_i = example_da_gev_forecast()
theta = fit_gev(data, stationary=False, time_dim="sample", check_fit=False)

# Check fitted params match params used to create data (might fail due to trend).
shape, loc, loc1, scale, scale1 = theta
npt.assert_allclose((shape, loc, scale), theta_i, rtol=rtol)

# Check it fitted a positive trend.
assert np.all(loc1) > 0
assert np.all(scale1) > 0

pvalue = check_gev_fit(data, theta, time_dim="sample")
assert np.all(pvalue) > alpha

0 comments on commit e66111e

Please sign in to comment.