Skip to content

Commit

Permalink
Update test_eva.py
Browse files Browse the repository at this point in the history
Add non stationary gev fit tests
  • Loading branch information
stellema committed Nov 10, 2023
1 parent 7f3471b commit dfc770a
Showing 1 changed file with 123 additions and 24 deletions.
147 changes: 123 additions & 24 deletions unseen/tests/test_eva.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import pytest

import numpy as np
import xarray as xr

Expand All @@ -8,30 +6,35 @@

from unseen.general_utils import fit_gev, check_gev_fit

rtol = 0.25 # relative tolerance
rtol = 0.3 # relative tolerance
alpha = 0.05
np.random.seed(0)


def example_da_gev_1d():
"""An example 1D timeseries DataArray"""
"""An example 1D GEV DataArray and its distribution parameters."""
time = np.arange("2000-01-01", "2002-01-01", dtype=np.datetime64)

c = np.random.rand()
loc = np.random.randint(-10, 10) + np.random.rand()
scale = np.random.rand()
theta = c, loc, scale
# Generate shape, location and scale parameters.
shape = np.random.uniform()
loc = np.random.uniform(-10, 10)
scale = np.random.uniform(0.1, 10)
theta = shape, loc, scale

rvs = genextreme.rvs(c, loc=loc, scale=scale, size=(time.size), random_state=0)
data = xr.DataArray(rvs, coords=[time], dims=['time'])
rvs = genextreme.rvs(shape, loc=loc, scale=scale, size=(time.size), random_state=0)
data = xr.DataArray(rvs, coords=[time], dims=["time"])
return data, theta


def example_da_gev_1d_dask():
"""An example 1D GEV dask array and its distribution parameters."""
data, theta = example_da_gev_1d()
data = data.chunk()
data = data.chunk({"time": -1})
return data, theta


def example_da_gev_3d():
"""An example multi-dim timeseries DataArray"""
"""An example 3D GEV DataArray and its distribution parameters."""
time = np.arange("2000-01-01", "2002-01-01", dtype=np.datetime64)
lats = np.arange(0, 2)
lons = np.arange(0, 2)
Expand All @@ -42,40 +45,136 @@ def example_da_gev_3d():
scale = np.random.rand(*shape)
theta = np.stack([c, loc, scale], axis=-1)

rvs = genextreme.rvs(c, loc=loc, scale=scale, size=(time.size, *shape), random_state=0)
data = xr.DataArray(rvs, coords=[time, lats, lons], dims=['time', 'lat', 'lon'])
rvs = genextreme.rvs(
c, loc=loc, scale=scale, size=(time.size, *shape), random_state=0
)
data = xr.DataArray(rvs, coords=[time, lats, lons], dims=["time", "lat", "lon"])
return data, theta


def example_da_gev_3d_dask():
"""An example 3D GEV dask array and its distribution parameters."""
data, theta = example_da_gev_3d()
data = data.chunk()
data = data.chunk({"time": -1, "lat": 1, "lon": 1})
return data, theta


def add_example_gev_trend(data):
trend = np.arange(data.time.size) / data.time.size
trend = xr.DataArray(trend, coords={"time": data.time})
return data + trend


def test_fit_gev_1d():
# 1D data matches given parameters.
"""Run stationary fit using 1D array & check results."""
data, theta_i = example_da_gev_1d()
theta = fit_gev(data, stationary=True, check_fit=False)
# Check fitted params match params used to create data.
npt.assert_allclose(theta, theta_i, rtol=rtol)


def test_fit_gev_1d_dask():
# 1D chunked data matches given parameters.
"""Run stationary fit using 1D dask array & check results."""
data, theta_i = example_da_gev_1d_dask()
theta = fit_gev(data, stationary=True, check_fit=False)
# Check fitted params match params used to create data.
npt.assert_allclose(theta, theta_i, rtol=rtol)


def test_fit_gev_3d():
# 3D data matches given parameters.
"""Run stationary fit using 3D array & check results."""
data, theta_i = example_da_gev_3d()
theta = fit_gev(data, stationary=True, check_fit=True)
theta = fit_gev(data, stationary=True, check_fit=False)
# Check fitted params match params used to create data.
npt.assert_allclose(theta, theta_i, rtol=rtol)


def test_fit_gev_3d_dask():
# 3D chunked data matches given parameters.
data, theta_i = example_da_gev_3d_dask()
theta = fit_gev(data, stationary=True, check_fit=True)
npt.assert_allclose(theta, theta_i, rtol=rtol)
# def test_fit_gev_3d_dask():
# """Run stationary fit using 3D dask array & check results."""
# data, theta_i = example_da_gev_3d_dask()
# theta = fit_gev(data, stationary=True, check_fit=False)
# # Check fitted params match params used to create data.
# npt.assert_allclose(theta, theta_i, rtol=rtol)


def test_fit_ns_gev_1d():
"""Run stationary fit using 1D array & check results."""
data, theta_i = example_da_gev_1d()
data = add_example_gev_trend(data)

# Check function runs.
theta = fit_gev(data, stationary=False, check_fit=False)

check_gev_fit(data, theta, time_dim="time")

# Check fitted params match params used to create data.
shape, loc, loc1, scale, scale1 = theta
npt.assert_allclose((shape, loc, scale), theta_i, rtol=rtol)

# Check it detected a positive treand
assert loc1 > 0
assert scale1 > 0


def test_fit_ns_gev_1d_dask():
"""Run stationary fit using 1D dask array & check results."""
data, theta_i = example_da_gev_1d_dask()

# Add a positive linear trend.
data = add_example_gev_trend(data)

# Check function runs.
theta = fit_gev(data, stationary=False, check_fit=False)

# Check fitted params match params used to create data.
shape, loc, loc1, scale, scale1 = theta
npt.assert_allclose((shape, loc, scale), theta_i, rtol=rtol)

# Check it fitted a positive trend.
assert np.all(loc1) > 0
assert np.all(scale1) > 0

pvalue = check_gev_fit(data, theta, time_dim="time")
assert np.all(pvalue) > alpha


def test_fit_ns_gev_3d():
"""Run stationary fit using 3D array & check results."""
data, theta_i = example_da_gev_3d()

# Add a positive linear trend.
data = add_example_gev_trend(data)

# Check function runs.
theta = fit_gev(data, stationary=False, check_fit=False)

# Check fitted params match params used to create data.
npt.assert_allclose(theta.isel(theta=[0, 1, 3]), theta_i, rtol=rtol)

# Check it fitted a positive trend.
assert np.all(theta.isel(theta=2)) > 0
assert np.all(theta.isel(theta=4)) > 0

pvalue = check_gev_fit(data, theta, time_dim="time")
assert np.all(pvalue) > alpha


# def test_fit_gev_3d_dask():
# """Run stationary fit using 3D dask array & check results."""
# data, theta_i = example_da_gev_3d_dask()

# # Add a positive linear trend.
# data = add_example_gev_trend(data)

# # Check function runs.
# theta = fit_gev(data, stationary=False, check_fit=False)

# # Check fitted params match params used to create data.
# npt.assert_allclose(theta.isel(theta=[0, 1, 3]), theta_i, rtol=rtol)

# # Check it fitted a positive trend.
# assert np.all(theta.isel(theta=2)) > 0
# assert np.all(theta.isel(theta=4)) > 0

# pvalue = check_gev_fit(data, theta, time_dim="time")
# assert np.all(pvalue) > alpha

0 comments on commit dfc770a

Please sign in to comment.