From 9668590b8bf6f9284bb20bf319e37ce31fc60bef Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 8 Aug 2024 17:22:15 -0300 Subject: [PATCH 01/11] [WIP] add khat and ps_min_ss --- src/arviz_stats/base/khat.py | 167 +++++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 src/arviz_stats/base/khat.py diff --git a/src/arviz_stats/base/khat.py b/src/arviz_stats/base/khat.py new file mode 100644 index 0000000..26ad042 --- /dev/null +++ b/src/arviz_stats/base/khat.py @@ -0,0 +1,167 @@ +import warnings +import numpy as np +from arviz import ess + +def pareto_khat(x, r_eff=None, tail="both", log_weights=False): + """ + + parameters + ---------- + x : DataArray + """ + ary = x.values.flatten() + + if log_weights: + tail = "right" + + ndraws = len(ary) + + if r_eff is None: + r_eff = ess(x.values, method="tail") / ndraws + + if ndraws > 255: + ndraws_tail = np.ceil(3 * (ndraws / r_eff)**0.5).astype(int) + else: + ndraws_tail = int(ndraws / 5) + + if tail == "both": + if ndraws_tail > ndraws / 2: + warnings.warn("Number of tail draws cannot be more than half " + "the total number of draws if both tails are fit, " + f"changing to {ndraws / 2}") + ndraws_tail = ndraws / 2 + + + if ndraws_tail < 5: + warnings.warn("Number of tail draws cannot be less than 5. " + "Changing to 5") + ndraws_tail = 5 + + k = max([pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=t)[1] for t in ("left", "right")]) + else: + _, k = pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=tail) + + + return k + +def ps_min_ss(k): + if k < 1: + return 10**(1 / (1 - max(0, k))) + else: + return np.inf + +def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail='both', log_weights=False): + if log_weights: + x = x - np.max(x) + + if tail not in ['right', 'left', 'both']: + raise ValueError('tail must be one of "right", "left", or "both"') + + tail_ids = np.arange(ndraws - ndraws_tail, ndraws) + + if tail == 'left': + x = -x + + ordered = np.argsort(x) + draws_tail = x[ordered[tail_ids]] + + cutoff = x[ordered[tail_ids[0] - 1]] # largest value smaller than tail values + + max_tail = np.max(draws_tail) + min_tail = np.min(draws_tail) + + if ndraws_tail >= 5: + if abs(max_tail - min_tail) < np.finfo(float).tiny: + raise ValueError('All tail values are the same') + + if log_weights: + draws_tail = np.exp(draws_tail) + cutoff = np.exp(cutoff) + + k, sigma = _gpdfit(draws_tail - cutoff) + + if np.isfinite(k) and smooth_draws: + p = (np.arange(0.5, ndraws_tail)) / ndraws_tail + smoothed = _gpinv(p, k, sigma, cutoff) + + if log_weights: + smoothed = np.log(smoothed) + else: + smoothed = None + else: + raise ValueError('ndraws_tail must be at least 5') + + if smoothed is not None: + smoothed[smoothed > max_tail] = max_tail + x[ordered[tail_ids]] = smoothed + + if tail == 'left': + x = -x + + return x, k + + +def _gpdfit(ary): + """Estimate the parameters for the Generalized Pareto Distribution (GPD). + + Empirical Bayes estimate for the parameters of the generalized Pareto + distribution given the data. 
+ + Parameters + ---------- + ary: array + sorted 1D data array + + Returns + ------- + k: float + estimated shape parameter + sigma: float + estimated scale parameter + """ + prior_bs = 3 + prior_k = 10 + n = len(ary) + m_est = 30 + int(n**0.5) + + b_ary = 1 - np.sqrt(m_est / (np.arange(1, m_est + 1, dtype=float) - 0.5)) + b_ary /= prior_bs * ary[int(n / 4 + 0.5) - 1] + b_ary += 1 / ary[-1] + + k_ary = np.log1p(-b_ary[:, None] * ary).mean(axis=1) # pylint: disable=no-member + len_scale = n * (np.log(-(b_ary / k_ary)) - k_ary - 1) + weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1) + + # remove negligible weights + real_idxs = weights >= 10 * np.finfo(float).eps + if not np.all(real_idxs): + weights = weights[real_idxs] + b_ary = b_ary[real_idxs] + # normalise weights + weights /= weights.sum() + + # posterior mean for b + b_post = np.sum(b_ary * weights) + # estimate for k + k_post = np.log1p(-b_post * ary).mean() # pylint: disable=invalid-unary-operand-type,no-member + # add prior for k_post + sigma = -k_post / b_post + k_post = (n * k_post + prior_k * 0.5) / (n + prior_k) + + return k_post, sigma + +def _gpinv(probs, kappa, sigma, mu): + """ + """ + if sigma <= 0: + return np.full_like(probs, np.nan) + + probs = 1 - probs + if kappa == 0: + q = mu - sigma * np.log1p(-probs) + else: + q = mu + sigma * np.expm1(-kappa * np.log1p(-probs)) / kappa + + return q + + From 621fc7dbf0d02e84fdc0bc65953128d4a7ef588d Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 9 Aug 2024 10:30:22 -0300 Subject: [PATCH 02/11] format --- src/arviz_stats/base/khat.py | 58 ++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/arviz_stats/base/khat.py b/src/arviz_stats/base/khat.py index 26ad042..4f36c3c 100644 --- a/src/arviz_stats/base/khat.py +++ b/src/arviz_stats/base/khat.py @@ -1,11 +1,13 @@ import warnings + import numpy as np from arviz import ess + def pareto_khat(x, r_eff=None, tail="both", log_weights=False): """ - - parameters + + Parameters ---------- x : DataArray """ @@ -20,46 +22,52 @@ def pareto_khat(x, r_eff=None, tail="both", log_weights=False): r_eff = ess(x.values, method="tail") / ndraws if ndraws > 255: - ndraws_tail = np.ceil(3 * (ndraws / r_eff)**0.5).astype(int) + ndraws_tail = np.ceil(3 * (ndraws / r_eff) ** 0.5).astype(int) else: ndraws_tail = int(ndraws / 5) if tail == "both": if ndraws_tail > ndraws / 2: - warnings.warn("Number of tail draws cannot be more than half " - "the total number of draws if both tails are fit, " - f"changing to {ndraws / 2}") + warnings.warn( + "Number of tail draws cannot be more than half " + "the total number of draws if both tails are fit, " + f"changing to {ndraws / 2}" + ) ndraws_tail = ndraws / 2 - if ndraws_tail < 5: - warnings.warn("Number of tail draws cannot be less than 5. " - "Changing to 5") + warnings.warn("Number of tail draws cannot be less than 5. 
" "Changing to 5") ndraws_tail = 5 - k = max([pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=t)[1] for t in ("left", "right")]) + k = max( + [ + pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=t)[1] + for t in ("left", "right") + ] + ) else: _, k = pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=tail) - return k + def ps_min_ss(k): if k < 1: - return 10**(1 / (1 - max(0, k))) + return 10 ** (1 / (1 - max(0, k))) else: return np.inf -def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail='both', log_weights=False): + +def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail="both", log_weights=False): if log_weights: x = x - np.max(x) - if tail not in ['right', 'left', 'both']: + if tail not in ["right", "left", "both"]: raise ValueError('tail must be one of "right", "left", or "both"') tail_ids = np.arange(ndraws - ndraws_tail, ndraws) - if tail == 'left': + if tail == "left": x = -x ordered = np.argsort(x) @@ -72,7 +80,7 @@ def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail='both', if ndraws_tail >= 5: if abs(max_tail - min_tail) < np.finfo(float).tiny: - raise ValueError('All tail values are the same') + raise ValueError("All tail values are the same") if log_weights: draws_tail = np.exp(draws_tail) @@ -83,21 +91,21 @@ def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail='both', if np.isfinite(k) and smooth_draws: p = (np.arange(0.5, ndraws_tail)) / ndraws_tail smoothed = _gpinv(p, k, sigma, cutoff) - + if log_weights: smoothed = np.log(smoothed) else: smoothed = None else: - raise ValueError('ndraws_tail must be at least 5') - + raise ValueError("ndraws_tail must be at least 5") + if smoothed is not None: smoothed[smoothed > max_tail] = max_tail - x[ordered[tail_ids]] = smoothed + x[ordered[tail_ids]] = smoothed - if tail == 'left': + if tail == "left": x = -x - + return x, k @@ -150,9 +158,9 @@ def _gpdfit(ary): return k_post, sigma + def _gpinv(probs, kappa, sigma, mu): - """ - """ + """ """ if sigma <= 0: return np.full_like(probs, np.nan) @@ -163,5 +171,3 @@ def _gpinv(probs, kappa, sigma, mu): q = mu + sigma * np.expm1(-kappa * np.log1p(-probs)) / kappa return q - - From fb3d4b42e55c0e5abcd718eedf5cba0390524fa1 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 9 Aug 2024 16:20:46 -0300 Subject: [PATCH 03/11] improve variable names and docstrings --- src/arviz_stats/base/khat.py | 149 ++++++++++++++++++++++++----------- 1 file changed, 102 insertions(+), 47 deletions(-) diff --git a/src/arviz_stats/base/khat.py b/src/arviz_stats/base/khat.py index 4f36c3c..d99d010 100644 --- a/src/arviz_stats/base/khat.py +++ b/src/arviz_stats/base/khat.py @@ -1,84 +1,134 @@ +"""Pareto k-hat diagnostics.""" + import warnings import numpy as np from arviz import ess -def pareto_khat(x, r_eff=None, tail="both", log_weights=False): +def pareto_khat(dt, r_eff=None, tail="both", log_weights=False): """ + Compute Pareto k-hat diagnostic. + + See details in Vehtari et al., 2024 (https://doi.org/10.48550/arXiv.1507.02646) Parameters ---------- x : DataArray + r_eff : float, optional + Relative efficiency. Effective sample size divided the number of samples. + If not provided, it will be estimated from the data. + tail : srt, optional + Which tail to fit. Can be 'right', 'left', or 'both'. + log_weights : bool, optional + Whether dt represents log-weights. + + Returns + ------- + khat : float + Pareto k-hat value. 
""" - ary = x.values.flatten() + ary = dt.values.flatten() if log_weights: tail = "right" - ndraws = len(ary) + n_draws = len(ary) if r_eff is None: - r_eff = ess(x.values, method="tail") / ndraws + r_eff = ess(dt.values, method="tail") / n_draws - if ndraws > 255: - ndraws_tail = np.ceil(3 * (ndraws / r_eff) ** 0.5).astype(int) + if n_draws > 255: + n_draws_tail = np.ceil(3 * (n_draws / r_eff) ** 0.5).astype(int) else: - ndraws_tail = int(ndraws / 5) + n_draws_tail = int(n_draws / 5) if tail == "both": - if ndraws_tail > ndraws / 2: + if n_draws_tail > n_draws / 2: warnings.warn( "Number of tail draws cannot be more than half " "the total number of draws if both tails are fit, " - f"changing to {ndraws / 2}" + f"changing to {n_draws / 2}" ) - ndraws_tail = ndraws / 2 + n_draws_tail = n_draws / 2 - if ndraws_tail < 5: - warnings.warn("Number of tail draws cannot be less than 5. " "Changing to 5") - ndraws_tail = 5 + if n_draws_tail < 5: + warnings.warn("Number of tail draws cannot be less than 5. Changing to 5") + n_draws_tail = 5 - k = max( - [ - pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=t)[1] - for t in ("left", "right") - ] + khat = max( + ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=t)[1] + for t in ("left", "right") ) else: - _, k = pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=tail) + _, khat = ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=tail) + + return khat - return k +def pareto_min_ss(k): + """ + Compute minimum effective sample size. + + See details in Vehtari et al., 2024 (https://doi.org/10.48550/arXiv.1507.02646) -def ps_min_ss(k): + Parameters + ---------- + k : float + Pareto k-hat value. + """ if k < 1: return 10 ** (1 / (1 - max(0, k))) - else: - return np.inf + return np.inf -def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail="both", log_weights=False): + +def ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail="both", log_weights=False): + """ + Estimate the tail of a distribution using the Generalized Pareto Distribution. + + Parameters + ---------- + x : array + 1D array. + n_draws : int + Number of draws. + n_draws_tail : int + Number of draws in the tail. + smooth_draws : bool, optional + Whether to smooth the tail. + tail : str, optional + Which tail to fit. Can be 'right', 'left', or 'both'. + log_weights : bool, optional + Whether x represents log-weights. + + Returns + ------- + ary : array + Array with smoothed tail values. + k : float + Estimated shape parameter. 
+ """ if log_weights: - x = x - np.max(x) + ary = ary - np.max(ary) if tail not in ["right", "left", "both"]: raise ValueError('tail must be one of "right", "left", or "both"') - tail_ids = np.arange(ndraws - ndraws_tail, ndraws) + tail_ids = np.arange(n_draws - n_draws_tail, n_draws) if tail == "left": - x = -x + ary = -ary - ordered = np.argsort(x) - draws_tail = x[ordered[tail_ids]] + ordered = np.argsort(ary) + draws_tail = ary[ordered[tail_ids]] - cutoff = x[ordered[tail_ids[0] - 1]] # largest value smaller than tail values + cutoff = ary[ordered[tail_ids[0] - 1]] # largest value smaller than tail values max_tail = np.max(draws_tail) min_tail = np.min(draws_tail) - if ndraws_tail >= 5: + if n_draws_tail >= 5: if abs(max_tail - min_tail) < np.finfo(float).tiny: raise ValueError("All tail values are the same") @@ -86,35 +136,40 @@ def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail="both", draws_tail = np.exp(draws_tail) cutoff = np.exp(cutoff) - k, sigma = _gpdfit(draws_tail - cutoff) + khat, sigma = _gpdfit(draws_tail - cutoff) - if np.isfinite(k) and smooth_draws: - p = (np.arange(0.5, ndraws_tail)) / ndraws_tail - smoothed = _gpinv(p, k, sigma, cutoff) + if np.isfinite(khat) and smooth_draws: + p = (np.arange(0.5, n_draws_tail)) / n_draws_tail + smoothed = _gpinv(p, khat, sigma, cutoff) if log_weights: smoothed = np.log(smoothed) else: smoothed = None else: - raise ValueError("ndraws_tail must be at least 5") + raise ValueError("n_draws_tail must be at least 5") if smoothed is not None: smoothed[smoothed > max_tail] = max_tail - x[ordered[tail_ids]] = smoothed + ary[ordered[tail_ids]] = smoothed if tail == "left": - x = -x + ary = -ary - return x, k + return ary, khat def _gpdfit(ary): """Estimate the parameters for the Generalized Pareto Distribution (GPD). - Empirical Bayes estimate for the parameters of the generalized Pareto + Empirical Bayes estimate for the parameters (kappa, sigma) of the generalized Pareto distribution given the data. + The fit uses a prior for kappa to stabilize estimates for very small (effective) + sample sizes. The weakly informative prior is a Gaussian centered at 0.5. 
+ See details in Vehtari et al., 2024 (https://doi.org/10.48550/arXiv.1507.02646) + + Parameters ---------- ary: array @@ -122,7 +177,7 @@ def _gpdfit(ary): Returns ------- - k: float + kappa: float estimated shape parameter sigma: float estimated scale parameter @@ -151,16 +206,16 @@ def _gpdfit(ary): # posterior mean for b b_post = np.sum(b_ary * weights) # estimate for k - k_post = np.log1p(-b_post * ary).mean() # pylint: disable=invalid-unary-operand-type,no-member - # add prior for k_post - sigma = -k_post / b_post - k_post = (n * k_post + prior_k * 0.5) / (n + prior_k) + kappa = np.log1p(-b_post * ary).mean() # pylint: disable=invalid-unary-operand-type,no-member + # add prior for kappa + sigma = -kappa / b_post + kappa = (n * kappa + prior_k * 0.5) / (n + prior_k) - return k_post, sigma + return kappa, sigma def _gpinv(probs, kappa, sigma, mu): - """ """ + """Quantile function for generalized pareto distribution.""" if sigma <= 0: return np.full_like(probs, np.nan) From 9a77c1f6a57e4d8a9d40c1becf66fcc2debec083 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 9 Aug 2024 16:40:59 -0300 Subject: [PATCH 04/11] make pareto_min_ss more indepedent --- src/arviz_stats/base/array.py | 3 +++ src/arviz_stats/base/khat.py | 22 +++++++++++----------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/src/arviz_stats/base/array.py b/src/arviz_stats/base/array.py index 97e3fa1..566da15 100644 --- a/src/arviz_stats/base/array.py +++ b/src/arviz_stats/base/array.py @@ -130,6 +130,9 @@ def mcse(self, ary, chain_axis=-2, draw_axis=-1, method="mean", prob=None): mcse_array = make_ufunc(mcse_func, n_output=1, n_input=1, n_dims=2, ravel=False) return mcse_array(ary, **func_kwargs) + def pareto_min_ss(self, ary): + """Compute minimum effective sample size.""" + def get_bins(self, ary, axes=-1): """Compute default bins.""" ary, axes = process_ary_axes(ary, axes) diff --git a/src/arviz_stats/base/khat.py b/src/arviz_stats/base/khat.py index d99d010..06fbcbb 100644 --- a/src/arviz_stats/base/khat.py +++ b/src/arviz_stats/base/khat.py @@ -6,7 +6,7 @@ from arviz import ess -def pareto_khat(dt, r_eff=None, tail="both", log_weights=False): +def pareto_khat(ary, r_eff=1, tail="both", log_weights=False): """ Compute Pareto k-hat diagnostic. @@ -14,10 +14,9 @@ def pareto_khat(dt, r_eff=None, tail="both", log_weights=False): Parameters ---------- - x : DataArray + ary : Array r_eff : float, optional Relative efficiency. Effective sample size divided the number of samples. - If not provided, it will be estimated from the data. tail : srt, optional Which tail to fit. Can be 'right', 'left', or 'both'. log_weights : bool, optional @@ -28,16 +27,11 @@ def pareto_khat(dt, r_eff=None, tail="both", log_weights=False): khat : float Pareto k-hat value. """ - ary = dt.values.flatten() - if log_weights: tail = "right" n_draws = len(ary) - if r_eff is None: - r_eff = ess(dt.values, method="tail") / n_draws - if n_draws > 255: n_draws_tail = np.ceil(3 * (n_draws / r_eff) ** 0.5).astype(int) else: @@ -66,7 +60,7 @@ def pareto_khat(dt, r_eff=None, tail="both", log_weights=False): return khat -def pareto_min_ss(k): +def pareto_min_ss(ary): """ Compute minimum effective sample size. @@ -77,8 +71,14 @@ def pareto_min_ss(k): k : float Pareto k-hat value. 
""" - if k < 1: - return 10 ** (1 / (1 - max(0, k))) + ary_flatten = ary.flatten() + n_draws = len(ary_flatten) + r_eff = ess(ary, method="tail") / n_draws + + kappa = pareto_khat(ary_flatten, r_eff=r_eff, tail="both", log_weights=False) + + if kappa < 1: + return 10 ** (1 / (1 - max(0, kappa))) return np.inf From 592455f55068fbd1aaef21eefdaf5aeed1831327 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 8 Aug 2024 17:22:15 -0300 Subject: [PATCH 05/11] [WIP] add khat and ps_min_ss --- src/arviz_stats/base/khat.py | 167 +++++++++++++++++++++++++++++++++++ 1 file changed, 167 insertions(+) create mode 100644 src/arviz_stats/base/khat.py diff --git a/src/arviz_stats/base/khat.py b/src/arviz_stats/base/khat.py new file mode 100644 index 0000000..26ad042 --- /dev/null +++ b/src/arviz_stats/base/khat.py @@ -0,0 +1,167 @@ +import warnings +import numpy as np +from arviz import ess + +def pareto_khat(x, r_eff=None, tail="both", log_weights=False): + """ + + parameters + ---------- + x : DataArray + """ + ary = x.values.flatten() + + if log_weights: + tail = "right" + + ndraws = len(ary) + + if r_eff is None: + r_eff = ess(x.values, method="tail") / ndraws + + if ndraws > 255: + ndraws_tail = np.ceil(3 * (ndraws / r_eff)**0.5).astype(int) + else: + ndraws_tail = int(ndraws / 5) + + if tail == "both": + if ndraws_tail > ndraws / 2: + warnings.warn("Number of tail draws cannot be more than half " + "the total number of draws if both tails are fit, " + f"changing to {ndraws / 2}") + ndraws_tail = ndraws / 2 + + + if ndraws_tail < 5: + warnings.warn("Number of tail draws cannot be less than 5. " + "Changing to 5") + ndraws_tail = 5 + + k = max([pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=t)[1] for t in ("left", "right")]) + else: + _, k = pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=tail) + + + return k + +def ps_min_ss(k): + if k < 1: + return 10**(1 / (1 - max(0, k))) + else: + return np.inf + +def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail='both', log_weights=False): + if log_weights: + x = x - np.max(x) + + if tail not in ['right', 'left', 'both']: + raise ValueError('tail must be one of "right", "left", or "both"') + + tail_ids = np.arange(ndraws - ndraws_tail, ndraws) + + if tail == 'left': + x = -x + + ordered = np.argsort(x) + draws_tail = x[ordered[tail_ids]] + + cutoff = x[ordered[tail_ids[0] - 1]] # largest value smaller than tail values + + max_tail = np.max(draws_tail) + min_tail = np.min(draws_tail) + + if ndraws_tail >= 5: + if abs(max_tail - min_tail) < np.finfo(float).tiny: + raise ValueError('All tail values are the same') + + if log_weights: + draws_tail = np.exp(draws_tail) + cutoff = np.exp(cutoff) + + k, sigma = _gpdfit(draws_tail - cutoff) + + if np.isfinite(k) and smooth_draws: + p = (np.arange(0.5, ndraws_tail)) / ndraws_tail + smoothed = _gpinv(p, k, sigma, cutoff) + + if log_weights: + smoothed = np.log(smoothed) + else: + smoothed = None + else: + raise ValueError('ndraws_tail must be at least 5') + + if smoothed is not None: + smoothed[smoothed > max_tail] = max_tail + x[ordered[tail_ids]] = smoothed + + if tail == 'left': + x = -x + + return x, k + + +def _gpdfit(ary): + """Estimate the parameters for the Generalized Pareto Distribution (GPD). + + Empirical Bayes estimate for the parameters of the generalized Pareto + distribution given the data. 
+ + Parameters + ---------- + ary: array + sorted 1D data array + + Returns + ------- + k: float + estimated shape parameter + sigma: float + estimated scale parameter + """ + prior_bs = 3 + prior_k = 10 + n = len(ary) + m_est = 30 + int(n**0.5) + + b_ary = 1 - np.sqrt(m_est / (np.arange(1, m_est + 1, dtype=float) - 0.5)) + b_ary /= prior_bs * ary[int(n / 4 + 0.5) - 1] + b_ary += 1 / ary[-1] + + k_ary = np.log1p(-b_ary[:, None] * ary).mean(axis=1) # pylint: disable=no-member + len_scale = n * (np.log(-(b_ary / k_ary)) - k_ary - 1) + weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1) + + # remove negligible weights + real_idxs = weights >= 10 * np.finfo(float).eps + if not np.all(real_idxs): + weights = weights[real_idxs] + b_ary = b_ary[real_idxs] + # normalise weights + weights /= weights.sum() + + # posterior mean for b + b_post = np.sum(b_ary * weights) + # estimate for k + k_post = np.log1p(-b_post * ary).mean() # pylint: disable=invalid-unary-operand-type,no-member + # add prior for k_post + sigma = -k_post / b_post + k_post = (n * k_post + prior_k * 0.5) / (n + prior_k) + + return k_post, sigma + +def _gpinv(probs, kappa, sigma, mu): + """ + """ + if sigma <= 0: + return np.full_like(probs, np.nan) + + probs = 1 - probs + if kappa == 0: + q = mu - sigma * np.log1p(-probs) + else: + q = mu + sigma * np.expm1(-kappa * np.log1p(-probs)) / kappa + + return q + + From ac86f46c69bc0a1a28dd7026363a7baa07d10e78 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 9 Aug 2024 10:30:22 -0300 Subject: [PATCH 06/11] format --- src/arviz_stats/base/khat.py | 58 ++++++++++++++++++++---------------- 1 file changed, 32 insertions(+), 26 deletions(-) diff --git a/src/arviz_stats/base/khat.py b/src/arviz_stats/base/khat.py index 26ad042..4f36c3c 100644 --- a/src/arviz_stats/base/khat.py +++ b/src/arviz_stats/base/khat.py @@ -1,11 +1,13 @@ import warnings + import numpy as np from arviz import ess + def pareto_khat(x, r_eff=None, tail="both", log_weights=False): """ - - parameters + + Parameters ---------- x : DataArray """ @@ -20,46 +22,52 @@ def pareto_khat(x, r_eff=None, tail="both", log_weights=False): r_eff = ess(x.values, method="tail") / ndraws if ndraws > 255: - ndraws_tail = np.ceil(3 * (ndraws / r_eff)**0.5).astype(int) + ndraws_tail = np.ceil(3 * (ndraws / r_eff) ** 0.5).astype(int) else: ndraws_tail = int(ndraws / 5) if tail == "both": if ndraws_tail > ndraws / 2: - warnings.warn("Number of tail draws cannot be more than half " - "the total number of draws if both tails are fit, " - f"changing to {ndraws / 2}") + warnings.warn( + "Number of tail draws cannot be more than half " + "the total number of draws if both tails are fit, " + f"changing to {ndraws / 2}" + ) ndraws_tail = ndraws / 2 - if ndraws_tail < 5: - warnings.warn("Number of tail draws cannot be less than 5. " - "Changing to 5") + warnings.warn("Number of tail draws cannot be less than 5. 
" "Changing to 5") ndraws_tail = 5 - k = max([pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=t)[1] for t in ("left", "right")]) + k = max( + [ + pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=t)[1] + for t in ("left", "right") + ] + ) else: _, k = pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=tail) - return k + def ps_min_ss(k): if k < 1: - return 10**(1 / (1 - max(0, k))) + return 10 ** (1 / (1 - max(0, k))) else: return np.inf -def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail='both', log_weights=False): + +def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail="both", log_weights=False): if log_weights: x = x - np.max(x) - if tail not in ['right', 'left', 'both']: + if tail not in ["right", "left", "both"]: raise ValueError('tail must be one of "right", "left", or "both"') tail_ids = np.arange(ndraws - ndraws_tail, ndraws) - if tail == 'left': + if tail == "left": x = -x ordered = np.argsort(x) @@ -72,7 +80,7 @@ def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail='both', if ndraws_tail >= 5: if abs(max_tail - min_tail) < np.finfo(float).tiny: - raise ValueError('All tail values are the same') + raise ValueError("All tail values are the same") if log_weights: draws_tail = np.exp(draws_tail) @@ -83,21 +91,21 @@ def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail='both', if np.isfinite(k) and smooth_draws: p = (np.arange(0.5, ndraws_tail)) / ndraws_tail smoothed = _gpinv(p, k, sigma, cutoff) - + if log_weights: smoothed = np.log(smoothed) else: smoothed = None else: - raise ValueError('ndraws_tail must be at least 5') - + raise ValueError("ndraws_tail must be at least 5") + if smoothed is not None: smoothed[smoothed > max_tail] = max_tail - x[ordered[tail_ids]] = smoothed + x[ordered[tail_ids]] = smoothed - if tail == 'left': + if tail == "left": x = -x - + return x, k @@ -150,9 +158,9 @@ def _gpdfit(ary): return k_post, sigma + def _gpinv(probs, kappa, sigma, mu): - """ - """ + """ """ if sigma <= 0: return np.full_like(probs, np.nan) @@ -163,5 +171,3 @@ def _gpinv(probs, kappa, sigma, mu): q = mu + sigma * np.expm1(-kappa * np.log1p(-probs)) / kappa return q - - From 428246c092f4f451447786206b270652ed2faaaf Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 9 Aug 2024 16:20:46 -0300 Subject: [PATCH 07/11] improve variable names and docstrings --- src/arviz_stats/base/khat.py | 149 ++++++++++++++++++++++++----------- 1 file changed, 102 insertions(+), 47 deletions(-) diff --git a/src/arviz_stats/base/khat.py b/src/arviz_stats/base/khat.py index 4f36c3c..d99d010 100644 --- a/src/arviz_stats/base/khat.py +++ b/src/arviz_stats/base/khat.py @@ -1,84 +1,134 @@ +"""Pareto k-hat diagnostics.""" + import warnings import numpy as np from arviz import ess -def pareto_khat(x, r_eff=None, tail="both", log_weights=False): +def pareto_khat(dt, r_eff=None, tail="both", log_weights=False): """ + Compute Pareto k-hat diagnostic. + + See details in Vehtari et al., 2024 (https://doi.org/10.48550/arXiv.1507.02646) Parameters ---------- x : DataArray + r_eff : float, optional + Relative efficiency. Effective sample size divided the number of samples. + If not provided, it will be estimated from the data. + tail : srt, optional + Which tail to fit. Can be 'right', 'left', or 'both'. + log_weights : bool, optional + Whether dt represents log-weights. + + Returns + ------- + khat : float + Pareto k-hat value. 
""" - ary = x.values.flatten() + ary = dt.values.flatten() if log_weights: tail = "right" - ndraws = len(ary) + n_draws = len(ary) if r_eff is None: - r_eff = ess(x.values, method="tail") / ndraws + r_eff = ess(dt.values, method="tail") / n_draws - if ndraws > 255: - ndraws_tail = np.ceil(3 * (ndraws / r_eff) ** 0.5).astype(int) + if n_draws > 255: + n_draws_tail = np.ceil(3 * (n_draws / r_eff) ** 0.5).astype(int) else: - ndraws_tail = int(ndraws / 5) + n_draws_tail = int(n_draws / 5) if tail == "both": - if ndraws_tail > ndraws / 2: + if n_draws_tail > n_draws / 2: warnings.warn( "Number of tail draws cannot be more than half " "the total number of draws if both tails are fit, " - f"changing to {ndraws / 2}" + f"changing to {n_draws / 2}" ) - ndraws_tail = ndraws / 2 + n_draws_tail = n_draws / 2 - if ndraws_tail < 5: - warnings.warn("Number of tail draws cannot be less than 5. " "Changing to 5") - ndraws_tail = 5 + if n_draws_tail < 5: + warnings.warn("Number of tail draws cannot be less than 5. Changing to 5") + n_draws_tail = 5 - k = max( - [ - pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=t)[1] - for t in ("left", "right") - ] + khat = max( + ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=t)[1] + for t in ("left", "right") ) else: - _, k = pareto_smooth_tail(ary, ndraws, ndraws_tail, smooth_draws=False, tail=tail) + _, khat = ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=tail) + + return khat - return k +def pareto_min_ss(k): + """ + Compute minimum effective sample size. + + See details in Vehtari et al., 2024 (https://doi.org/10.48550/arXiv.1507.02646) -def ps_min_ss(k): + Parameters + ---------- + k : float + Pareto k-hat value. + """ if k < 1: return 10 ** (1 / (1 - max(0, k))) - else: - return np.inf + return np.inf -def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail="both", log_weights=False): + +def ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail="both", log_weights=False): + """ + Estimate the tail of a distribution using the Generalized Pareto Distribution. + + Parameters + ---------- + x : array + 1D array. + n_draws : int + Number of draws. + n_draws_tail : int + Number of draws in the tail. + smooth_draws : bool, optional + Whether to smooth the tail. + tail : str, optional + Which tail to fit. Can be 'right', 'left', or 'both'. + log_weights : bool, optional + Whether x represents log-weights. + + Returns + ------- + ary : array + Array with smoothed tail values. + k : float + Estimated shape parameter. 
+ """ if log_weights: - x = x - np.max(x) + ary = ary - np.max(ary) if tail not in ["right", "left", "both"]: raise ValueError('tail must be one of "right", "left", or "both"') - tail_ids = np.arange(ndraws - ndraws_tail, ndraws) + tail_ids = np.arange(n_draws - n_draws_tail, n_draws) if tail == "left": - x = -x + ary = -ary - ordered = np.argsort(x) - draws_tail = x[ordered[tail_ids]] + ordered = np.argsort(ary) + draws_tail = ary[ordered[tail_ids]] - cutoff = x[ordered[tail_ids[0] - 1]] # largest value smaller than tail values + cutoff = ary[ordered[tail_ids[0] - 1]] # largest value smaller than tail values max_tail = np.max(draws_tail) min_tail = np.min(draws_tail) - if ndraws_tail >= 5: + if n_draws_tail >= 5: if abs(max_tail - min_tail) < np.finfo(float).tiny: raise ValueError("All tail values are the same") @@ -86,35 +136,40 @@ def pareto_smooth_tail(x, ndraws, ndraws_tail, smooth_draws=False, tail="both", draws_tail = np.exp(draws_tail) cutoff = np.exp(cutoff) - k, sigma = _gpdfit(draws_tail - cutoff) + khat, sigma = _gpdfit(draws_tail - cutoff) - if np.isfinite(k) and smooth_draws: - p = (np.arange(0.5, ndraws_tail)) / ndraws_tail - smoothed = _gpinv(p, k, sigma, cutoff) + if np.isfinite(khat) and smooth_draws: + p = (np.arange(0.5, n_draws_tail)) / n_draws_tail + smoothed = _gpinv(p, khat, sigma, cutoff) if log_weights: smoothed = np.log(smoothed) else: smoothed = None else: - raise ValueError("ndraws_tail must be at least 5") + raise ValueError("n_draws_tail must be at least 5") if smoothed is not None: smoothed[smoothed > max_tail] = max_tail - x[ordered[tail_ids]] = smoothed + ary[ordered[tail_ids]] = smoothed if tail == "left": - x = -x + ary = -ary - return x, k + return ary, khat def _gpdfit(ary): """Estimate the parameters for the Generalized Pareto Distribution (GPD). - Empirical Bayes estimate for the parameters of the generalized Pareto + Empirical Bayes estimate for the parameters (kappa, sigma) of the generalized Pareto distribution given the data. + The fit uses a prior for kappa to stabilize estimates for very small (effective) + sample sizes. The weakly informative prior is a Gaussian centered at 0.5. 
+ See details in Vehtari et al., 2024 (https://doi.org/10.48550/arXiv.1507.02646) + + Parameters ---------- ary: array @@ -122,7 +177,7 @@ def _gpdfit(ary): Returns ------- - k: float + kappa: float estimated shape parameter sigma: float estimated scale parameter @@ -151,16 +206,16 @@ def _gpdfit(ary): # posterior mean for b b_post = np.sum(b_ary * weights) # estimate for k - k_post = np.log1p(-b_post * ary).mean() # pylint: disable=invalid-unary-operand-type,no-member - # add prior for k_post - sigma = -k_post / b_post - k_post = (n * k_post + prior_k * 0.5) / (n + prior_k) + kappa = np.log1p(-b_post * ary).mean() # pylint: disable=invalid-unary-operand-type,no-member + # add prior for kappa + sigma = -kappa / b_post + kappa = (n * kappa + prior_k * 0.5) / (n + prior_k) - return k_post, sigma + return kappa, sigma def _gpinv(probs, kappa, sigma, mu): - """ """ + """Quantile function for generalized pareto distribution.""" if sigma <= 0: return np.full_like(probs, np.nan) From 5cf4fb1cf28666e314d5f5a6e4f76c60bd9379d1 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Fri, 9 Aug 2024 16:40:59 -0300 Subject: [PATCH 08/11] make pareto_min_ss more indepedent --- src/arviz_stats/base/khat.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/arviz_stats/base/khat.py b/src/arviz_stats/base/khat.py index d99d010..06fbcbb 100644 --- a/src/arviz_stats/base/khat.py +++ b/src/arviz_stats/base/khat.py @@ -6,7 +6,7 @@ from arviz import ess -def pareto_khat(dt, r_eff=None, tail="both", log_weights=False): +def pareto_khat(ary, r_eff=1, tail="both", log_weights=False): """ Compute Pareto k-hat diagnostic. @@ -14,10 +14,9 @@ def pareto_khat(dt, r_eff=None, tail="both", log_weights=False): Parameters ---------- - x : DataArray + ary : Array r_eff : float, optional Relative efficiency. Effective sample size divided the number of samples. - If not provided, it will be estimated from the data. tail : srt, optional Which tail to fit. Can be 'right', 'left', or 'both'. log_weights : bool, optional @@ -28,16 +27,11 @@ def pareto_khat(dt, r_eff=None, tail="both", log_weights=False): khat : float Pareto k-hat value. """ - ary = dt.values.flatten() - if log_weights: tail = "right" n_draws = len(ary) - if r_eff is None: - r_eff = ess(dt.values, method="tail") / n_draws - if n_draws > 255: n_draws_tail = np.ceil(3 * (n_draws / r_eff) ** 0.5).astype(int) else: @@ -66,7 +60,7 @@ def pareto_khat(dt, r_eff=None, tail="both", log_weights=False): return khat -def pareto_min_ss(k): +def pareto_min_ss(ary): """ Compute minimum effective sample size. @@ -77,8 +71,14 @@ def pareto_min_ss(k): k : float Pareto k-hat value. 
""" - if k < 1: - return 10 ** (1 / (1 - max(0, k))) + ary_flatten = ary.flatten() + n_draws = len(ary_flatten) + r_eff = ess(ary, method="tail") / n_draws + + kappa = pareto_khat(ary_flatten, r_eff=r_eff, tail="both", log_weights=False) + + if kappa < 1: + return 10 ** (1 / (1 - max(0, kappa))) return np.inf From 7c4b792c2edca3fdbe28096126f38fcd473e7101 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 12 Sep 2024 11:50:42 -0300 Subject: [PATCH 09/11] add accesors --- src/arviz_stats/accessors.py | 12 +++++++++++ src/arviz_stats/base/array.py | 9 +++++++- src/arviz_stats/base/dataarray.py | 11 ++++++++++ src/arviz_stats/base/diagnostics.py | 15 +++++++++++++ src/arviz_stats/base/{khat.py => pareto.py} | 24 --------------------- 5 files changed, 46 insertions(+), 25 deletions(-) rename src/arviz_stats/base/{khat.py => pareto.py} (91%) diff --git a/src/arviz_stats/accessors.py b/src/arviz_stats/accessors.py index c5eaef7..3bbe40c 100644 --- a/src/arviz_stats/accessors.py +++ b/src/arviz_stats/accessors.py @@ -55,6 +55,10 @@ def thin(self, factor="auto", dims=None, **kwargs): """Perform thinning on the DataArray.""" return get_function("thin")(self._obj, factor=factor, dims=dims, **kwargs) + def pareto_min_ss(self): + """Compute the minimum effective sample size on the DataArray.""" + return get_function("pareto_min_ss")(self._obj) + @xr.register_dataset_accessor("azstats") class AzStatsDsAccessor(_BaseAccessor): @@ -147,6 +151,10 @@ def thin(self, dims=None, factor="auto"): """Perform thinning for all the variables in the dataset.""" return self._apply(get_function("thin"), dims=dims, factor=factor) + def pareto_min_ss(self, dims=None): + """Compute the min sample size for all variables in the dataset.""" + return self._apply("pareto_min_ss", dims=dims) + @register_datatree_accessor("azstats") class AzStatsDtAccessor(_BaseAccessor): @@ -215,3 +223,7 @@ def histogram(self, dims=None, group="posterior", **kwargs): def thin(self, dims=None, group="posterior", **kwargs): """Perform thinning for all variables in a group of the DataTree.""" return self._apply("thin", dims=dims, group=group, **kwargs) + + def pareto_min_ss(self, dims=None, group="posterior"): + """Compute the min sample size for all variables in a group of the DataTree.""" + return self._apply("pareto_min_ss", dims=dims, group=group) diff --git a/src/arviz_stats/base/array.py b/src/arviz_stats/base/array.py index f916e5f..07fcd1f 100644 --- a/src/arviz_stats/base/array.py +++ b/src/arviz_stats/base/array.py @@ -132,8 +132,15 @@ def mcse(self, ary, chain_axis=-2, draw_axis=-1, method="mean", prob=None): mcse_array = make_ufunc(mcse_func, n_output=1, n_input=1, n_dims=2, ravel=False) return mcse_array(ary, **func_kwargs) - def pareto_min_ss(self, ary): + def pareto_min_ss(self, ary, chain_axis=-2, draw_axis=-1): """Compute minimum effective sample size.""" + if chain_axis is None: + ary = np.expand_dims(ary, axis=0) + chain_axis = 0 + ary, _ = process_ary_axes(ary, [chain_axis, draw_axis]) + pms_func = getattr(self, "_pareto_min_ss") + pms_array = make_ufunc(pms_func, n_output=1, n_input=1, n_dims=2, ravel=False) + return pms_array(ary) def compute_ranks(self, ary, axes=-1, relative=False): """Compute ranks of MCMC samples.""" diff --git a/src/arviz_stats/base/dataarray.py b/src/arviz_stats/base/dataarray.py index 3e3354d..ec10d8f 100644 --- a/src/arviz_stats/base/dataarray.py +++ b/src/arviz_stats/base/dataarray.py @@ -240,5 +240,16 @@ def thin(self, da, factor="auto", dims=None): return da.sel({dims: slice(None, None, 
factor)}) + def pareto_min_ss(self, da, dims=None): + """Compute the minimum effective sample size for all variables in the dataset.""" + if dims is None: + dims = rcParams["data.sample_dims"] + return apply_ufunc( + self.array_class.pareto_min_ss, + da, + input_core_dims=[dims], + output_core_dims=[[]], + ) + dataarray_stats = BaseDataArray(array_class=array_stats) diff --git a/src/arviz_stats/base/diagnostics.py b/src/arviz_stats/base/diagnostics.py index 625584a..8fcd1b5 100644 --- a/src/arviz_stats/base/diagnostics.py +++ b/src/arviz_stats/base/diagnostics.py @@ -7,6 +7,7 @@ from scipy import stats from arviz_stats.base.core import _CoreBase +from arviz_stats.base.pareto import pareto_khat from arviz_stats.base.stats_utils import not_valid as _not_valid @@ -327,3 +328,17 @@ def _mcse_quantile(self, ary, prob): th1 = sorted_ary[int(np.floor(np.nanmax((ppf_size[0], 0))))] th2 = sorted_ary[int(np.ceil(np.nanmin((ppf_size[1], size - 1))))] return (th2 - th1) / 2 + + def _pareto_min_ss(self, ary): + """Compute the minimum effective sample size.""" + ary = np.asarray(ary) + ary_flatten = ary.flatten() + n_draws = len(ary_flatten) + r_eff = self._ess_tail(ary) / n_draws + + kappa = pareto_khat(ary_flatten, r_eff=r_eff, tail="both", log_weights=False) + + if kappa < 1: + return 10 ** (1 / (1 - max(0, kappa))) + + return np.inf diff --git a/src/arviz_stats/base/khat.py b/src/arviz_stats/base/pareto.py similarity index 91% rename from src/arviz_stats/base/khat.py rename to src/arviz_stats/base/pareto.py index 06fbcbb..3a58788 100644 --- a/src/arviz_stats/base/khat.py +++ b/src/arviz_stats/base/pareto.py @@ -3,7 +3,6 @@ import warnings import numpy as np -from arviz import ess def pareto_khat(ary, r_eff=1, tail="both", log_weights=False): @@ -60,29 +59,6 @@ def pareto_khat(ary, r_eff=1, tail="both", log_weights=False): return khat -def pareto_min_ss(ary): - """ - Compute minimum effective sample size. - - See details in Vehtari et al., 2024 (https://doi.org/10.48550/arXiv.1507.02646) - - Parameters - ---------- - k : float - Pareto k-hat value. - """ - ary_flatten = ary.flatten() - n_draws = len(ary_flatten) - r_eff = ess(ary, method="tail") / n_draws - - kappa = pareto_khat(ary_flatten, r_eff=r_eff, tail="both", log_weights=False) - - if kappa < 1: - return 10 ** (1 / (1 - max(0, kappa))) - - return np.inf - - def ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail="both", log_weights=False): """ Estimate the tail of a distribution using the Generalized Pareto Distribution. 
From d0d40ff914125b99fbef1fa53e183a311f547944 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 12 Sep 2024 12:18:35 -0300 Subject: [PATCH 10/11] set min draws and chains --- src/arviz_stats/base/diagnostics.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/arviz_stats/base/diagnostics.py b/src/arviz_stats/base/diagnostics.py index 8fcd1b5..9c4b81b 100644 --- a/src/arviz_stats/base/diagnostics.py +++ b/src/arviz_stats/base/diagnostics.py @@ -332,6 +332,8 @@ def _mcse_quantile(self, ary, prob): def _pareto_min_ss(self, ary): """Compute the minimum effective sample size.""" ary = np.asarray(ary) + if _not_valid(ary, shape_kwargs={"min_draws": 4, "min_chains": 1}): + return np.nan ary_flatten = ary.flatten() n_draws = len(ary_flatten) r_eff = self._ess_tail(ary) / n_draws From d26d3c076f04054562f9ffa2acd3177b3379e961 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Wed, 18 Sep 2024 12:24:55 -0300 Subject: [PATCH 11/11] fix per comments --- src/arviz_stats/accessors.py | 4 +- src/arviz_stats/base/array.py | 3 +- src/arviz_stats/base/diagnostics.py | 206 +++++++++++++++++++++++++++- src/arviz_stats/base/pareto.py | 204 --------------------------- 4 files changed, 205 insertions(+), 212 deletions(-) delete mode 100644 src/arviz_stats/base/pareto.py diff --git a/src/arviz_stats/accessors.py b/src/arviz_stats/accessors.py index 3bbe40c..e0272b7 100644 --- a/src/arviz_stats/accessors.py +++ b/src/arviz_stats/accessors.py @@ -55,9 +55,9 @@ def thin(self, factor="auto", dims=None, **kwargs): """Perform thinning on the DataArray.""" return get_function("thin")(self._obj, factor=factor, dims=dims, **kwargs) - def pareto_min_ss(self): + def pareto_min_ss(self, dims=None): """Compute the minimum effective sample size on the DataArray.""" - return get_function("pareto_min_ss")(self._obj) + return get_function("pareto_min_ss")(self._obj, dims=dims) @xr.register_dataset_accessor("azstats") diff --git a/src/arviz_stats/base/array.py b/src/arviz_stats/base/array.py index 07fcd1f..050d322 100644 --- a/src/arviz_stats/base/array.py +++ b/src/arviz_stats/base/array.py @@ -138,8 +138,7 @@ def pareto_min_ss(self, ary, chain_axis=-2, draw_axis=-1): ary = np.expand_dims(ary, axis=0) chain_axis = 0 ary, _ = process_ary_axes(ary, [chain_axis, draw_axis]) - pms_func = getattr(self, "_pareto_min_ss") - pms_array = make_ufunc(pms_func, n_output=1, n_input=1, n_dims=2, ravel=False) + pms_array = make_ufunc(self._pareto_min_ss, n_output=1, n_input=1, n_dims=2, ravel=False) return pms_array(ary) def compute_ranks(self, ary, axes=-1, relative=False): diff --git a/src/arviz_stats/base/diagnostics.py b/src/arviz_stats/base/diagnostics.py index 9c4b81b..da71b52 100644 --- a/src/arviz_stats/base/diagnostics.py +++ b/src/arviz_stats/base/diagnostics.py @@ -1,13 +1,13 @@ # pylint: disable=too-many-lines, too-many-function-args, redefined-outer-name """Diagnostic functions for ArviZ.""" +import warnings from collections.abc import Sequence import numpy as np from scipy import stats from arviz_stats.base.core import _CoreBase -from arviz_stats.base.pareto import pareto_khat from arviz_stats.base.stats_utils import not_valid as _not_valid @@ -335,12 +335,210 @@ def _pareto_min_ss(self, ary): if _not_valid(ary, shape_kwargs={"min_draws": 4, "min_chains": 1}): return np.nan ary_flatten = ary.flatten() - n_draws = len(ary_flatten) - r_eff = self._ess_tail(ary) / n_draws + r_eff = self._ess_tail(ary, relative=True) - kappa = pareto_khat(ary_flatten, r_eff=r_eff, tail="both", log_weights=False) + kappa = 
self._pareto_khat(ary_flatten, r_eff=r_eff, tail="both", log_weights=False) if kappa < 1: return 10 ** (1 / (1 - max(0, kappa))) return np.inf + + def _pareto_khat(self, ary, r_eff=1, tail="both", log_weights=False): + """ + Compute Pareto k-hat diagnostic. + + See details in Vehtari et al., 2024 (https://doi.org/10.48550/arXiv.1507.02646) + + Parameters + ---------- + ary : Array + r_eff : float, optional + Relative efficiency. Effective sample size divided the number of samples. + tail : srt, optional + Which tail to fit. Can be 'right', 'left', or 'both'. + log_weights : bool, optional + Whether dt represents log-weights. + + Returns + ------- + khat : float + Pareto k-hat value. + """ + if log_weights: + tail = "right" + + n_draws = len(ary) + + if n_draws > 255: + n_draws_tail = np.ceil(3 * (n_draws / r_eff) ** 0.5).astype(int) + else: + n_draws_tail = int(n_draws / 5) + + if tail == "both": + if n_draws_tail > n_draws / 2: + warnings.warn( + "Number of tail draws cannot be more than half " + "the total number of draws if both tails are fit, " + f"changing to {n_draws / 2}" + ) + n_draws_tail = n_draws / 2 + + if n_draws_tail < 5: + warnings.warn("Number of tail draws cannot be less than 5. Changing to 5") + n_draws_tail = 5 + + khat = max( + self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=t)[1] + for t in ("left", "right") + ) + else: + _, khat = self._ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=tail) + + return khat + + def _ps_tail( + self, ary, n_draws, n_draws_tail, smooth_draws=False, tail="both", log_weights=False + ): + """ + Estimate the tail of a distribution using the Generalized Pareto Distribution. + + Parameters + ---------- + x : array + 1D array. + n_draws : int + Number of draws. + n_draws_tail : int + Number of draws in the tail. + smooth_draws : bool, optional + Whether to smooth the tail. + tail : str, optional + Which tail to fit. Can be 'right', 'left', or 'both'. + log_weights : bool, optional + Whether x represents log-weights. + + Returns + ------- + ary : array + Array with smoothed tail values. + k : float + Estimated shape parameter. + """ + if log_weights: + ary = ary - np.max(ary) + + if tail not in ["right", "left", "both"]: + raise ValueError('tail must be one of "right", "left", or "both"') + + tail_ids = np.arange(n_draws - n_draws_tail, n_draws) + + if tail == "left": + ary = -ary + + ordered = np.argsort(ary) + draws_tail = ary[ordered[tail_ids]] + + cutoff = ary[ordered[tail_ids[0] - 1]] # largest value smaller than tail values + + max_tail = np.max(draws_tail) + min_tail = np.min(draws_tail) + + if n_draws_tail >= 5: + if abs(max_tail - min_tail) < np.finfo(float).tiny: + raise ValueError("All tail values are the same") + + if log_weights: + draws_tail = np.exp(draws_tail) + cutoff = np.exp(cutoff) + + khat, sigma = self._gpdfit(draws_tail - cutoff) + + if np.isfinite(khat) and smooth_draws: + p = (np.arange(0.5, n_draws_tail)) / n_draws_tail + smoothed = self._gpinv(p, khat, sigma, cutoff) + + if log_weights: + smoothed = np.log(smoothed) + else: + smoothed = None + else: + raise ValueError("n_draws_tail must be at least 5") + + if smoothed is not None: + smoothed[smoothed > max_tail] = max_tail + ary[ordered[tail_ids]] = smoothed + + if tail == "left": + ary = -ary + + return ary, khat + + @staticmethod + def _gpdfit(ary): + """Estimate the parameters for the Generalized Pareto Distribution (GPD). 
+ + Empirical Bayes estimate for the parameters (kappa, sigma) of the generalized Pareto + distribution given the data. + + The fit uses a prior for kappa to stabilize estimates for very small (effective) + sample sizes. The weakly informative prior is a Gaussian centered at 0.5. + See details in Vehtari et al., 2024 (https://doi.org/10.48550/arXiv.1507.02646) + + + Parameters + ---------- + ary: array + sorted 1D data array + + Returns + ------- + kappa: float + estimated shape parameter + sigma: float + estimated scale parameter + """ + prior_bs = 3 + prior_k = 10 + n = len(ary) + m_est = 30 + int(n**0.5) + + b_ary = 1 - np.sqrt(m_est / (np.arange(1, m_est + 1, dtype=float) - 0.5)) + b_ary /= prior_bs * ary[int(n / 4 + 0.5) - 1] + b_ary += 1 / ary[-1] + + k_ary = np.log1p(-b_ary[:, None] * ary).mean(axis=1) # pylint: disable=no-member + len_scale = n * (np.log(-(b_ary / k_ary)) - k_ary - 1) + weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1) + + # remove negligible weights + real_idxs = weights >= 10 * np.finfo(float).eps + if not np.all(real_idxs): + weights = weights[real_idxs] + b_ary = b_ary[real_idxs] + # normalise weights + weights /= weights.sum() + + # posterior mean for b + b_post = np.sum(b_ary * weights) + # estimate for k + kappa = np.log1p(-b_post * ary).mean() # pylint: disable=invalid-unary-operand-type,no-member + # add prior for kappa + sigma = -kappa / b_post + kappa = (n * kappa + prior_k * 0.5) / (n + prior_k) + + return kappa, sigma + + @staticmethod + def _gpinv(probs, kappa, sigma, mu): + """Quantile function for generalized pareto distribution.""" + if sigma <= 0: + return np.full_like(probs, np.nan) + + probs = 1 - probs + if kappa == 0: + q = mu - sigma * np.log1p(-probs) + else: + q = mu + sigma * np.expm1(-kappa * np.log1p(-probs)) / kappa + + return q diff --git a/src/arviz_stats/base/pareto.py b/src/arviz_stats/base/pareto.py deleted file mode 100644 index 3a58788..0000000 --- a/src/arviz_stats/base/pareto.py +++ /dev/null @@ -1,204 +0,0 @@ -"""Pareto k-hat diagnostics.""" - -import warnings - -import numpy as np - - -def pareto_khat(ary, r_eff=1, tail="both", log_weights=False): - """ - Compute Pareto k-hat diagnostic. - - See details in Vehtari et al., 2024 (https://doi.org/10.48550/arXiv.1507.02646) - - Parameters - ---------- - ary : Array - r_eff : float, optional - Relative efficiency. Effective sample size divided the number of samples. - tail : srt, optional - Which tail to fit. Can be 'right', 'left', or 'both'. - log_weights : bool, optional - Whether dt represents log-weights. - - Returns - ------- - khat : float - Pareto k-hat value. - """ - if log_weights: - tail = "right" - - n_draws = len(ary) - - if n_draws > 255: - n_draws_tail = np.ceil(3 * (n_draws / r_eff) ** 0.5).astype(int) - else: - n_draws_tail = int(n_draws / 5) - - if tail == "both": - if n_draws_tail > n_draws / 2: - warnings.warn( - "Number of tail draws cannot be more than half " - "the total number of draws if both tails are fit, " - f"changing to {n_draws / 2}" - ) - n_draws_tail = n_draws / 2 - - if n_draws_tail < 5: - warnings.warn("Number of tail draws cannot be less than 5. 
Changing to 5") - n_draws_tail = 5 - - khat = max( - ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=t)[1] - for t in ("left", "right") - ) - else: - _, khat = ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail=tail) - - return khat - - -def ps_tail(ary, n_draws, n_draws_tail, smooth_draws=False, tail="both", log_weights=False): - """ - Estimate the tail of a distribution using the Generalized Pareto Distribution. - - Parameters - ---------- - x : array - 1D array. - n_draws : int - Number of draws. - n_draws_tail : int - Number of draws in the tail. - smooth_draws : bool, optional - Whether to smooth the tail. - tail : str, optional - Which tail to fit. Can be 'right', 'left', or 'both'. - log_weights : bool, optional - Whether x represents log-weights. - - Returns - ------- - ary : array - Array with smoothed tail values. - k : float - Estimated shape parameter. - """ - if log_weights: - ary = ary - np.max(ary) - - if tail not in ["right", "left", "both"]: - raise ValueError('tail must be one of "right", "left", or "both"') - - tail_ids = np.arange(n_draws - n_draws_tail, n_draws) - - if tail == "left": - ary = -ary - - ordered = np.argsort(ary) - draws_tail = ary[ordered[tail_ids]] - - cutoff = ary[ordered[tail_ids[0] - 1]] # largest value smaller than tail values - - max_tail = np.max(draws_tail) - min_tail = np.min(draws_tail) - - if n_draws_tail >= 5: - if abs(max_tail - min_tail) < np.finfo(float).tiny: - raise ValueError("All tail values are the same") - - if log_weights: - draws_tail = np.exp(draws_tail) - cutoff = np.exp(cutoff) - - khat, sigma = _gpdfit(draws_tail - cutoff) - - if np.isfinite(khat) and smooth_draws: - p = (np.arange(0.5, n_draws_tail)) / n_draws_tail - smoothed = _gpinv(p, khat, sigma, cutoff) - - if log_weights: - smoothed = np.log(smoothed) - else: - smoothed = None - else: - raise ValueError("n_draws_tail must be at least 5") - - if smoothed is not None: - smoothed[smoothed > max_tail] = max_tail - ary[ordered[tail_ids]] = smoothed - - if tail == "left": - ary = -ary - - return ary, khat - - -def _gpdfit(ary): - """Estimate the parameters for the Generalized Pareto Distribution (GPD). - - Empirical Bayes estimate for the parameters (kappa, sigma) of the generalized Pareto - distribution given the data. - - The fit uses a prior for kappa to stabilize estimates for very small (effective) - sample sizes. The weakly informative prior is a Gaussian centered at 0.5. 
- See details in Vehtari et al., 2024 (https://doi.org/10.48550/arXiv.1507.02646) - - - Parameters - ---------- - ary: array - sorted 1D data array - - Returns - ------- - kappa: float - estimated shape parameter - sigma: float - estimated scale parameter - """ - prior_bs = 3 - prior_k = 10 - n = len(ary) - m_est = 30 + int(n**0.5) - - b_ary = 1 - np.sqrt(m_est / (np.arange(1, m_est + 1, dtype=float) - 0.5)) - b_ary /= prior_bs * ary[int(n / 4 + 0.5) - 1] - b_ary += 1 / ary[-1] - - k_ary = np.log1p(-b_ary[:, None] * ary).mean(axis=1) # pylint: disable=no-member - len_scale = n * (np.log(-(b_ary / k_ary)) - k_ary - 1) - weights = 1 / np.exp(len_scale - len_scale[:, None]).sum(axis=1) - - # remove negligible weights - real_idxs = weights >= 10 * np.finfo(float).eps - if not np.all(real_idxs): - weights = weights[real_idxs] - b_ary = b_ary[real_idxs] - # normalise weights - weights /= weights.sum() - - # posterior mean for b - b_post = np.sum(b_ary * weights) - # estimate for k - kappa = np.log1p(-b_post * ary).mean() # pylint: disable=invalid-unary-operand-type,no-member - # add prior for kappa - sigma = -kappa / b_post - kappa = (n * kappa + prior_k * 0.5) / (n + prior_k) - - return kappa, sigma - - -def _gpinv(probs, kappa, sigma, mu): - """Quantile function for generalized pareto distribution.""" - if sigma <= 0: - return np.full_like(probs, np.nan) - - probs = 1 - probs - if kappa == 0: - q = mu - sigma * np.log1p(-probs) - else: - q = mu + sigma * np.expm1(-kappa * np.log1p(-probs)) / kappa - - return q
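
A minimal usage sketch of what this series adds (not itself part of any commit above): with PATCH 09/11 and PATCH 11/11 applied, the diagnostic should be reachable through the `.azstats` accessors. The snippet assumes that importing `arviz_stats` registers those accessors, that the package's `get_function` dispatch and `rcParams["data.sample_dims"]` defaults resolve `pareto_min_ss` as wired up in these commits, and it invents the random data and the variable name "x" purely for illustration.

    # Sketch: Pareto-k-hat-based minimum sample size on a synthetic posterior.
    import numpy as np
    import xarray as xr

    import arviz_stats  # noqa: F401  (importing is assumed to register the .azstats accessors)

    rng = np.random.default_rng(42)

    # 4 chains x 1000 draws of a heavy-tailed quantity, laid out with the
    # ("chain", "draw") sample-dims convention the accessors expect.
    posterior = xr.Dataset(
        {"x": (("chain", "draw"), rng.standard_t(df=3, size=(4, 1000)))}
    )

    # One value per variable: the minimum effective sample size implied by the
    # Pareto k-hat fitted to both tails (see _pareto_min_ss in PATCH 11/11).
    print(posterior.azstats.pareto_min_ss())

Per the `_pareto_min_ss` branch in PATCH 11/11, the returned value floors at 10 draws when k-hat <= 0, grows as 10 ** (1 / (1 - k-hat)) while k-hat approaches 1, and becomes np.inf once k-hat >= 1.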