From 7053713c61a4e0dd43c5d30ab730e83f650aedd3 Mon Sep 17 00:00:00 2001 From: Annette Stellema Date: Wed, 21 Aug 2024 15:41:16 +1000 Subject: [PATCH] Fix fit_gev relative_fit_test check and add tests --- unseen/eva.py | 2 +- unseen/tests/test_eva.py | 34 ++++++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/unseen/eva.py b/unseen/eva.py index 921c29b..455b565 100644 --- a/unseen/eva.py +++ b/unseen/eva.py @@ -319,7 +319,7 @@ def _fit( result = check_gev_relative_fit( data, L1, L2, test=relative_fit_test, alpha=alpha ) - if result is False: + if not result: warnings.warn( f"{relative_fit_test} test failed. Returning stationary parameters." ) diff --git a/unseen/tests/test_eva.py b/unseen/tests/test_eva.py index 01947c0..bba533f 100644 --- a/unseen/tests/test_eva.py +++ b/unseen/tests/test_eva.py @@ -235,6 +235,40 @@ def test_fit_ns_gev_3d(): assert np.all(theta.isel(theta=2) > 0) # Positive trend in location +def test_fit_ns_gev_1d_relative_fit_test_bic_trend(): + """Run non-stationary fit & check 'BIC' test returns nonstationary params.""" + data, _ = example_da_gev_1d() + # Add a large positive linear trend + data = add_example_gev_trend(data) + data = add_example_gev_trend(data) + covariate = np.arange(data.time.size, dtype=int) + + theta = fit_gev( + data, + stationary=False, + core_dim="time", + covariate=covariate, + relative_fit_test="bic", + ) + assert np.all(theta[2] > 0) # Positive trend in location + + +def test_fit_ns_gev_1d_relative_fit_test_bic_no_trend(): + """Run non-stationary fit & check 'BIC' test returns stationary params.""" + data, _ = example_da_gev_1d() + covariate = np.arange(data.time.size, dtype=int) + + theta = fit_gev( + data, + stationary=False, + core_dim="time", + covariate=covariate, + relative_fit_test="bic", + ) + assert np.all(theta[2] == 0) # No trend in location + assert np.all(theta[4] == 0) # No trend in scale + + def test_fit_ns_gev_3d_dask(): """Run non-stationary fit using 3D dask array & check results.""" data, _ = example_da_gev_3d_dask()