Skip to content

Commit

Permalink
Merge pull request #305 from nspope/failsafe-2f1
Browse files Browse the repository at this point in the history
Numeric fixes for extremely large shape parameters
  • Loading branch information
hyanwong authored Aug 1, 2023
2 parents 667a49d + 4b3cd9f commit 41c2689
Show file tree
Hide file tree
Showing 5 changed files with 514 additions and 152 deletions.
177 changes: 164 additions & 13 deletions tests/test_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import numpy as np
import pytest
import scipy.integrate
import scipy.special
import scipy.stats
from distribution_functions import conditional_coalescent_pdf
from distribution_functions import kl_divergence
Expand All @@ -44,18 +45,6 @@
]


def approximate_gamma_mom(mean, variance):
    """
    Use the method of moments to approximate a distribution with a gamma of the
    same mean and variance.

    The returned shape ``alpha`` and rate ``beta`` satisfy
    ``alpha / beta == mean`` and ``alpha / beta**2 == variance``.

    :param mean: Mean of the distribution to match; must be positive.
    :param variance: Variance of the distribution to match; must be positive.
    :return: Tuple ``(alpha, beta)`` of gamma shape and rate.
    :raises AssertionError: If ``mean`` or ``variance`` is not strictly
        positive (this includes NaN).
    """
    # Raise explicitly instead of using ``assert`` so that the checks are not
    # stripped under ``python -O``; ``not x > 0`` also rejects NaN.
    if not mean > 0:
        raise AssertionError("mean must be positive")
    if not variance > 0:
        raise AssertionError("variance must be positive")
    alpha = mean**2 / variance
    beta = mean / variance
    return alpha, beta


@pytest.mark.parametrize("pars", _gamma_trio_test_cases)
class TestPosteriorMomentMatching:
"""
Expand Down Expand Up @@ -129,6 +118,60 @@ def test_sufficient_statistics(self, pars):
)[0]
assert np.isclose(ln_t_j, ck_ln_t_j, rtol=1e-3)

def test_mean_and_variance(self, pars):
    """
    Compare the log normalizer, means and variances returned by
    ``approx.mean_and_variance`` with values obtained by numerically
    integrating ``self.pdf``.
    """

    def quad(integrand):
        # outer variable tj runs over [0, inf), inner variable ti over [tj, inf)
        return scipy.integrate.dblquad(
            integrand,
            0,
            np.inf,
            lambda tj: tj,
            np.inf,
            epsabs=0,
        )[0]

    logconst, t_i, var_t_i, t_j, var_t_j = approx.mean_and_variance(*pars)

    # normalizing constant
    ck_normconst = quad(lambda ti, tj: self.pdf(ti, tj, *pars))
    assert np.isclose(logconst, np.log(ck_normconst), rtol=1e-3)

    # first moments
    ck_t_i = quad(lambda ti, tj: ti * self.pdf(ti, tj, *pars) / ck_normconst)
    assert np.isclose(t_i, ck_t_i, rtol=1e-3)
    ck_t_j = quad(lambda ti, tj: tj * self.pdf(ti, tj, *pars) / ck_normconst)
    assert np.isclose(t_j, ck_t_j, rtol=1e-3)

    # variances, as second moment minus squared mean
    second_moment_i = quad(
        lambda ti, tj: ti**2 * self.pdf(ti, tj, *pars) / ck_normconst
    )
    ck_var_t_i = second_moment_i - ck_t_i**2
    assert np.isclose(var_t_i, ck_var_t_i, rtol=1e-3)
    second_moment_j = quad(
        lambda ti, tj: tj**2 * self.pdf(ti, tj, *pars) / ck_normconst
    )
    ck_var_t_j = second_moment_j - ck_t_j**2
    assert np.isclose(var_t_j, ck_var_t_j, rtol=1e-3)

def test_approximate_gamma(self, pars):
_, t_i, ln_t_i, t_j, ln_t_j = approx.sufficient_statistics(*pars)
alpha_i, beta_i = approx.approximate_gamma_kl(t_i, ln_t_i)
Expand Down Expand Up @@ -183,7 +226,7 @@ def test_approximate_gamma(self, k):
x = self.priors[self.n][k][mean_column]
xvar = self.priors[self.n][k][var_column]
# match mean/variance
alpha_0, beta_0 = approximate_gamma_mom(x, xvar)
alpha_0, beta_0 = approx.approximate_gamma_mom(x, xvar)
ck_x = alpha_0 / beta_0
ck_xvar = alpha_0 / beta_0**2
assert np.isclose(x, ck_x)
Expand All @@ -205,3 +248,111 @@ def test_approximate_gamma(self, k):
lambda x: scipy.stats.gamma.logpdf(x, alpha_1, scale=1 / beta_1),
)
assert kl_1 < kl_0


@pytest.mark.parametrize(
    "pars",
    [
        [1.62, 0.00074, 25603.8, 0.6653, 0.0, 0.0011],  # "Cancellation error"
    ],
)
class Test2F1Failsafe:
    """
    Test approximation of marginal pairwise joint distributions by a gamma via
    arbitrary precision mean/variance matching, when sufficient statistics
    calculation fails
    """

    def test_sufficient_statistics_throws_exception(self, pars):
        """The sufficient statistics calculation must fail for these inputs."""
        with pytest.raises(Exception, match="Cancellation error"):
            approx.sufficient_statistics(*pars)

    def test_exception_uses_mean_and_variance(self, pars):
        """
        The projection must return the same gamma parameters as moment
        matching applied to the output of ``approx.mean_and_variance``.
        """
        _, t_i, va_t_i, t_j, va_t_j = approx.mean_and_variance(*pars)
        shape_i, rate_i = approx.approximate_gamma_mom(t_i, va_t_i)
        shape_j, rate_j = approx.approximate_gamma_mom(t_j, va_t_j)
        _, par_i, par_j = approx.gamma_projection(*pars)
        proj_shape_i, proj_rate_i = par_i
        proj_shape_j, proj_rate_j = par_j
        comparisons = (
            (shape_i, proj_shape_i),
            (rate_i, proj_rate_i),
            (shape_j, proj_shape_j),
            (rate_j, proj_rate_j),
        )
        for expected, observed in comparisons:
            assert np.isclose(expected, observed)


class TestGammaFactorization:
    """
    Test various functions for manipulating factorizations of gamma distributions
    """

    @staticmethod
    def _implied_prior(posterior, in_message, out_message):
        """
        Return the prior ``[shape, rate]`` implied by a posterior and its
        messages, using

            posterior_shape = prior_shape + sum(in_shape - 1) + sum(out_shape - 1)
            posterior_rate = prior_rate + sum(in_rate) + sum(out_rate)

        Messages are arrays with one ``[shape, rate]`` row per message.
        """
        return np.array(
            [
                posterior[0]
                - np.sum(in_message[:, 0] - 1)
                - np.sum(out_message[:, 0] - 1),
                posterior[1] - np.sum(in_message[:, 1]) - np.sum(out_message[:, 1]),
            ]
        )

    def test_rescale_gamma(self):
        """
        Rescaling to a target posterior shape must conserve the posterior
        mean and the relative magnitude of prior and incoming messages.
        """
        in_message = np.array([[1.5, 0.25], [1.5, 0.25]])
        out_message = np.array([[1.5, 0.25], [1.5, 0.25]])
        posterior = np.array([4, 1.5])  # prior is implicitly [2, 0.5]
        prior = self._implied_prior(posterior, in_message, out_message)
        # rescale
        target_shape = 12
        new_post, new_in, new_out = approx.rescale_gamma(
            posterior, in_message, out_message, target_shape
        )
        new_prior = self._implied_prior(new_post, new_in, new_out)
        assert new_post[0] == target_shape
        # mean is conserved
        assert np.isclose(new_post[0] / new_post[1], posterior[0] / posterior[1])
        # magnitude of messages (in natural parameterization) is conserved
        assert np.isclose(
            (new_prior[0] - 1) / np.sum(new_in[:, 0] - 1),
            (prior[0] - 1) / np.sum(in_message[:, 0] - 1),
        )
        assert np.isclose(
            new_prior[1] / np.sum(new_in[:, 1]),
            prior[1] / np.sum(in_message[:, 1]),
        )

    def test_average_gammas(self):
        """
        The averaged gamma must match the mean of E[x] and of E[log x] over
        the input gammas.
        """
        # E[x] = shape/rate
        # E[log x] = digamma(shape) - log(rate)
        shape = np.array([0.5, 1.5])
        rate = np.array([1.0, 1.0])
        avg_shape, avg_rate = approx.average_gammas(shape, rate)
        # include the rate terms so the expectations stay correct if the
        # rates above are ever changed from one
        E_x = np.mean(shape / rate)
        E_logx = np.mean(scipy.special.digamma(shape) - np.log(rate))
        assert np.isclose(E_x, avg_shape / avg_rate)
        assert np.isclose(E_logx, scipy.special.digamma(avg_shape) - np.log(avg_rate))


class TestKLMinimizationFailed:
    """
    Test errors in KL minimization
    """

    def test_violates_jensen(self):
        """Inputs (1, 0) must be rejected with a Jensen's inequality error."""
        with pytest.raises(approx.KLMinimizationFailed, match="violates Jensen's"):
            approx.approximate_gamma_kl(1, 0)

    def test_asymptotic_bound(self):
        """Check the returned shape against ``-0.5 / logx`` near the 1e4 threshold."""
        # over the threshold the bound itself is returned (no optimization)
        logx = -0.000001
        alpha_bound = -0.5 / logx
        alpha, _ = approx.approximate_gamma_kl(1, logx)
        assert alpha == alpha_bound
        assert alpha > 1e4
        # just under the threshold the optimization result matches the bound
        logx = -0.000051
        alpha_bound = -0.5 / logx
        alpha, _ = approx.approximate_gamma_kl(1, logx)
        assert np.abs(alpha - alpha_bound) < 1
        assert alpha < 1e4
63 changes: 40 additions & 23 deletions tests/test_hypergeo.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@
from tsdate import hypergeo


@pytest.mark.parametrize("x", [1e-10, 1e-6, 1e-2, 1e1, 1e2, 1e3, 1e5, 1e10])
@pytest.mark.parametrize("x", [-0.3, 1e-10, 1e-6, 1e-2, 1e1, 1e2, 1e3, 1e5, 1e10])
class TestPolygamma:
"""
Test numba-fied gamma functions
"""

def test_gammaln(self, x):
    """Compare ``hypergeo._gammaln`` with mpmath's log-gamma."""
    # mpmath.loggamma is complex-valued for negative arguments (the
    # parameter list includes -0.3), so compare against its real part
    assert np.isclose(hypergeo._gammaln(x), float(mpmath.re(mpmath.loggamma(x))))

def test_digamma(self, x):
    """Compare ``hypergeo._digamma`` with mpmath's polygamma of order zero."""
    observed = hypergeo._digamma(x)
    expected = float(mpmath.psi(0, x))
    assert np.isclose(observed, expected)
Expand All @@ -51,7 +51,8 @@ def test_trigamma(self, x):

def test_betaln(self, x):
    """Compare ``hypergeo._betaln(x, 2x)`` with the log of mpmath's beta function."""
    # the beta function can be negative for negative arguments (the
    # parameter list includes -0.3), making its log complex: compare
    # against the real part
    assert np.isclose(
        hypergeo._betaln(x, 2 * x),
        float(mpmath.re(mpmath.log(mpmath.beta(x, 2 * x)))),
    )


Expand Down Expand Up @@ -105,9 +106,9 @@ def test_2f1_grad(self, pars):
list(
itertools.product(
[0.8, 20.3, 200.2],
[0.0, 1.0, 10.0, 51.0],
[0.0, 1.0, 10.0, 31.0],
[1.6, 30.5, 300.7],
[1.1, 1.5, 1.9],
[1.1, 1.5, 1.9, 4.2],
)
),
)
Expand Down Expand Up @@ -186,11 +187,11 @@ def _2f1(a, b, c, z):

def test_is_valid_2f1(self, pars):
    """
    The derivatives returned by ``self._2f1`` must pass the validity check,
    and must fail it once perturbed.
    """
    # NOTE(review): the trailing 1e-10 is presumably the tolerance of the
    # validity check -- confirm against hypergeo._is_valid_2f1
    dz, d2z = self._2f1(*pars)
    assert hypergeo._is_valid_2f1(dz, d2z, *pars, 1e-10)
    # perturb solution to differential equation
    dz *= 1 + 1e-3
    d2z *= 1 - 1e-3
    assert not hypergeo._is_valid_2f1(dz, d2z, *pars, 1e-10)


@pytest.mark.parametrize("muts", [0.0, 1.0, 5.0, 10.0])
Expand Down Expand Up @@ -243,25 +244,41 @@ def test_2f1_grad(self, muts, hyp2f1_func, pars):


@pytest.mark.parametrize(
    "func, pars, err",
    [
        [
            hypergeo._hyp2f1_dlmf1583,
            [-21.62, 0.00074, 1003.8, 0.7653, 100.0, 0.0011],
            "Cancellation error",
        ],
        [
            hypergeo._hyp2f1_dlmf1583,
            [1.62, 0.00074, 25603.8, 0.6653, 0.0, 0.0011],
            "Cancellation error",
        ],
        # TODO: gives zero function value, then reroutes through dlmf1581
        # [
        #     hypergeo._hyp2f1_dlmf1583,
        #     [9007.39, 0.241, 10000, 0.2673, 2.0, 0.01019],
        #     "Cancellation error",
        # ],
        [
            hypergeo._hyp2f1_dlmf1581,
            [1.62, 0.00074, 25603.8, 0.7653, 100.0, 0.0011],
            "Maximum terms",
        ],
        [
            hypergeo._hyp2f1_dlmf1583,
            [1.0, 1.0, 1.0, 1.0, 3.0, 0.0],
            "Zero division",
        ],
    ],
)
class TestInvalid2F1:
    """
    Test cases where homegrown 2F1 fails to converge
    """

    def test_hyp2f1_error(self, func, pars, err):
        """
        Calling ``func`` with ``pars`` must raise ``hypergeo.Invalid2F1``
        with a message matching ``err``.
        """
        with pytest.raises(hypergeo.Invalid2F1, match=err):
            func(*pars)
Loading

0 comments on commit 41c2689

Please sign in to comment.