diff --git a/src/arviz_stats/base/array.py b/src/arviz_stats/base/array.py index 00c7c32..583e71f 100644 --- a/src/arviz_stats/base/array.py +++ b/src/arviz_stats/base/array.py @@ -47,16 +47,24 @@ def hdi( circular=False, max_modes=10, skipna=False, + **kwargs, ): - """Compute of HDI function on array-like input.""" + """Compute HDI function on array-like input.""" if not 1 >= prob > 0: raise ValueError("The value of `prob` must be in the (0, 1] interval.") - if method == "multimodal" and circular: - raise ValueError("Multimodal hdi not supported for circular data.") ary, axes = process_ary_axes(ary, axes) + is_discrete = np.issubdtype(ary.dtype, np.integer) or np.issubdtype(ary.dtype, np.bool_) + is_multimodal = method.startswith("multimodal") + if is_multimodal and circular and is_discrete: + raise ValueError("Multimodal hdi not supported for discrete circular data.") hdi_func = { "nearest": self._hdi_nearest, - "multimodal": self._hdi_multimodal, + "multimodal": ( + self._hdi_multimodal_discrete if is_discrete else self._hdi_multimodal_continuous + ), + "multimodal_sample": ( + self._hdi_multimodal_discrete if is_discrete else self._hdi_multimodal_continuous + ), }[method] hdi_array = make_ufunc( hdi_func, @@ -67,15 +75,23 @@ def hdi( func_kwargs = { "prob": prob, "skipna": skipna, - "out_shape": (max_modes, 2) if method == "multimodal" else (2,), + "out_shape": (max_modes, 2) if is_multimodal else (2,), + "circular": circular, } - if method != "multimodal": - func_kwargs["circular"] = circular - else: + if is_multimodal: func_kwargs["max_modes"] = max_modes + if is_discrete: + func_kwargs.pop("circular") + func_kwargs.pop("skipna") + else: + func_kwargs["bw"] = "isj" if not circular else "taylor" + func_kwargs.update(kwargs) + + if method == "multimodal_sample": + func_kwargs["from_sample"] = True result = hdi_array(ary, **func_kwargs) - if method == "multimodal": + if is_multimodal: mode_mask = [np.all(np.isnan(result[..., i, :])) for i in range(result.shape[-2])] result = result[..., ~np.array(mode_mask), :] return result diff --git a/src/arviz_stats/base/core.py b/src/arviz_stats/base/core.py index 6c83bc2..b0becdd 100644 --- a/src/arviz_stats/base/core.py +++ b/src/arviz_stats/base/core.py @@ -86,6 +86,10 @@ def circular_mean(self, ary): # pylint: disable=no-self-use """ return circmean(ary, high=np.pi, low=-np.pi) + def _circular_standardize(self, ary): # pylint: disable=no-self-use + """Standardize circular data to the interval [-pi, pi].""" + return np.mod(ary + np.pi, 2 * np.pi) - np.pi + def quantile(self, ary, quantile, **kwargs): # pylint: disable=no-self-use """Compute the quantile of an array of samples. @@ -226,20 +230,9 @@ def _histogram(self, ary, bins=None, range=None, weights=None, density=None): bins = self._get_bins(ary) return np.histogram(ary, bins=bins, range=range, weights=weights, density=density) - def _hdi_linear_nearest_common(self, ary, prob, skipna, circular): - ary = ary.flatten() - if skipna: - nans = np.isnan(ary) - if not nans.all(): - ary = ary[~nans] + def _hdi_linear_nearest_common(self, ary, prob): # pylint: disable=no-self-use n = len(ary) - mean = None - if circular: - mean = self.circular_mean(ary) - ary = ary - mean - ary = np.arctan2(np.sin(ary), np.cos(ary)) - ary = np.sort(ary) interval_idx_inc = int(np.floor(prob * n)) n_intervals = n - interval_idx_inc @@ -249,62 +242,147 @@ def _hdi_linear_nearest_common(self, ary, prob, skipna, circular): raise ValueError("Too few elements for interval calculation. ") min_idx = np.argmin(interval_width) + hdi_interval = ary[[min_idx, min_idx + interval_idx_inc]] - return ary, mean, min_idx, interval_idx_inc + return hdi_interval def _hdi_nearest(self, ary, prob, circular, skipna): """Compute HDI over the flattened array as closest samples that contain the given prob.""" - ary, mean, min_idx, interval_idx_inc = self._hdi_linear_nearest_common( - ary, prob, skipna, circular - ) - - hdi_min = ary[min_idx] - hdi_max = ary[min_idx + interval_idx_inc] + ary = ary.flatten() + if skipna: + nans = np.isnan(ary) + if not nans.all(): + ary = ary[~nans] if circular: - hdi_min = hdi_min + mean - hdi_max = hdi_max + mean - hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min)) - hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max)) + mean = self.circular_mean(ary) + ary = self._circular_standardize(ary - mean) + + hdi_interval = self._hdi_linear_nearest_common(ary, prob) - hdi_interval = np.array([hdi_min, hdi_max]) + if circular: + hdi_interval = self._circular_standardize(hdi_interval + mean) return hdi_interval - def _hdi_multimodal(self, ary, prob, skipna, max_modes): + def _hdi_multimodal_continuous( + self, ary, prob, skipna, max_modes, circular, from_sample=False, **kwargs + ): """Compute HDI if the distribution is multimodal.""" ary = ary.flatten() if skipna: ary = ary[~np.isnan(ary)] - if ary.dtype.kind == "f": - bins, density, _ = self.kde(ary) - lower, upper = bins[0], bins[-1] - range_x = upper - lower - dx = range_x / len(density) + bins, density, _ = self.kde(ary, circular=circular, **kwargs) + if from_sample: + ary_density = np.interp(ary, bins, density) + hdi_intervals, interval_probs = self._hdi_from_point_densities( + ary, ary_density, prob, circular + ) else: - bins = self._get_bins(ary) - density, _ = self._histogram(ary, bins=bins, density=True) - dx = np.diff(bins)[0] - - density *= dx - - idx = np.argsort(-density) - intervals = bins[idx][density[idx].cumsum() <= prob] - intervals.sort() - - intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1) - - hdi_intervals = np.full((max_modes, 2), np.nan) - for i, interval in enumerate(intervals_splitted): - if i == max_modes: - warnings.warn( - f"found more modes than {max_modes}, returning only the first {max_modes} modes" - ) - break - if interval.size == 0: - hdi_intervals[i] = np.asarray([bins[0], bins[0]]) - else: - hdi_intervals[i] = np.asarray([interval[0], interval[-1]]) - - return np.array(hdi_intervals) + dx = (bins[-1] - bins[0]) / (len(bins) - 1) + bin_probs = density * dx + + hdi_intervals, interval_probs = self._hdi_from_bin_probabilities( + bins, bin_probs, prob, circular, dx + ) + + return self._pad_hdi_to_maxmodes(hdi_intervals, interval_probs, max_modes) + + def _hdi_multimodal_discrete(self, ary, prob, max_modes, bins=None): + """Compute HDI if the distribution is multimodal.""" + ary = ary.flatten() + + if bins is None: + bins, counts = np.unique(ary, return_counts=True) + bin_probs = counts / len(ary) + dx = 1 + else: + counts, edges = self._histogram(ary, bins=bins) + bins = 0.5 * (edges[1:] + edges[:-1]) + bin_probs = counts / counts.sum() + dx = bins[1] - bins[0] + + hdi_intervals, interval_probs = self._hdi_from_bin_probabilities( + bins, bin_probs, prob, False, dx + ) + + return self._pad_hdi_to_maxmodes(hdi_intervals, interval_probs, max_modes) + + def _hdi_from_point_densities(self, points, densities, prob, circular): + if circular: + points = self._circular_standardize(points) + + sorted_idx = np.argsort(points) + points = points[sorted_idx] + densities = densities[sorted_idx] + + # find idx of points in the interval + interval_size = int(np.ceil(prob * len(points))) + sorted_idx = np.argsort(densities)[::-1] + idx_in_interval = sorted_idx[:interval_size] + idx_in_interval.sort() + + # find idx of interval bounds + probs_in_interval = np.full(idx_in_interval.shape, 1 / len(points)) + interval_bounds_idx, interval_probs = self._interval_points_to_bounds( + idx_in_interval, probs_in_interval, 1, circular, period=len(points) + ) + + return points[interval_bounds_idx], interval_probs + + def _hdi_from_bin_probabilities(self, bins, bin_probs, prob, circular, dx): + if circular: + bins = self._circular_standardize(bins) + sorted_idx = np.argsort(bins) + bins = bins[sorted_idx] + bin_probs = bin_probs[sorted_idx] + + # find idx of bins in the interval + sorted_idx = np.argsort(bin_probs)[::-1] + cum_probs = bin_probs[sorted_idx].cumsum() + interval_size = np.searchsorted(cum_probs, prob, side="left") + 1 + idx_in_interval = sorted_idx[:interval_size] + idx_in_interval.sort() + + # get points in intervals + intervals = bins[idx_in_interval] + probs_in_interval = bin_probs[idx_in_interval] + + return self._interval_points_to_bounds(intervals, probs_in_interval, dx, circular) + + def _interval_points_to_bounds(self, points, probs, dx, circular, period=2 * np.pi): # pylint: disable=no-self-use + cum_probs = probs.cumsum() + + is_bound = np.diff(points) > dx * 1.01 + is_lower_bound = np.insert(is_bound, 0, True) + is_upper_bound = np.append(is_bound, True) + interval_bounds = np.column_stack([points[is_lower_bound], points[is_upper_bound]]) + interval_probs = ( + cum_probs[is_upper_bound] - cum_probs[is_lower_bound] + probs[is_lower_bound] + ) + + if ( + circular + and np.mod(dx * 1.01 + interval_bounds[-1, -1] - interval_bounds[0, 0], period) + <= dx * 1.01 + ): + interval_bounds[-1, 1] = interval_bounds[0, 1] + interval_bounds = interval_bounds[1:, :] + interval_probs[-1] += interval_probs[0] + interval_probs = interval_probs[1:] + + return interval_bounds, interval_probs + + def _pad_hdi_to_maxmodes(self, hdi_intervals, interval_probs, max_modes): # pylint: disable=no-self-use + if hdi_intervals.shape[0] > max_modes: + warnings.warn( + f"found more modes than {max_modes}, returning only the {max_modes} highest " + "probability modes" + ) + hdi_intervals = hdi_intervals[np.argsort(interval_probs)[::-1][:max_modes], :] + elif hdi_intervals.shape[0] < max_modes: + hdi_intervals = np.vstack( + [hdi_intervals, np.full((max_modes - hdi_intervals.shape[0], 2), np.nan)] + ) + return hdi_intervals diff --git a/src/arviz_stats/base/dataarray.py b/src/arviz_stats/base/dataarray.py index c5b0815..73b0c4c 100644 --- a/src/arviz_stats/base/dataarray.py +++ b/src/arviz_stats/base/dataarray.py @@ -34,9 +34,7 @@ def eti(self, da, prob=None, dims=None, method="linear"): kwargs={"axis": np.arange(-len(dims), 0, 1), "method": method}, ) - def hdi( - self, da, prob=None, dims=None, method="nearest", circular=False, max_modes=10, skipna=False - ): + def hdi(self, da, prob=None, dims=None, method="nearest", **kwargs): """Compute hdi on DataArray input.""" dims = validate_dims(dims) prob = validate_ci_prob(prob) @@ -48,13 +46,11 @@ def hdi( da, prob, input_core_dims=[dims, []], - output_core_dims=[[mode_dim, "hdi"] if method == "multimodal" else ["hdi"]], + output_core_dims=[[mode_dim, "hdi"] if method.startswith("multimodal") else ["hdi"]], kwargs={ - "method": method, - "circular": circular, - "skipna": skipna, - "max_modes": max_modes, "axes": np.arange(-len(dims), 0, 1), + "method": method, + **kwargs, }, ).assign_coords({"hdi": hdi_coord}) diff --git a/tests/base/test_stats.py b/tests/base/test_stats.py index 3dc194b..925ba1f 100644 --- a/tests/base/test_stats.py +++ b/tests/base/test_stats.py @@ -3,6 +3,7 @@ import pytest from arviz_base import load_arviz_data, ndarray_to_dataarray, rc_context from numpy.testing import assert_array_almost_equal, assert_array_equal +from scipy.stats import bernoulli, norm, poisson from xarray import DataArray, Dataset @@ -88,15 +89,67 @@ def test_hdi_coords(centered_eight): assert_array_equal(result.coords["chain"], [0, 1, 3]) -def test_hdi_multimodal(): +@pytest.mark.parametrize("prob", [0.56, 0.83]) +@pytest.mark.parametrize("nearest", [True, False]) +def test_hdi_multimodal_continuous(prob, nearest): + method = "multimodal_sample" if nearest else "multimodal" rng = np.random.default_rng(43) + dist1 = norm(loc=-30, scale=0.5) + dist2 = norm(loc=30, scale=0.5) normal_sample = ndarray_to_dataarray( - np.concatenate((rng.normal(-4, 1, 2500000), rng.normal(2, 0.5, 2500000))), + np.concatenate( + (dist1.rvs(2500000, random_state=rng), dist2.rvs(2500000, random_state=rng)) + ), "x", sample_dims=["sample"], ) - intervals = normal_sample.azstats.hdi(dims="sample", method="multimodal", prob=0.83) - assert_array_almost_equal(intervals, [[-5.1, -2.8], [1.1, 2.8]], 1) + exact_hdis = np.concatenate( + [ + np.array(dist1.interval(prob))[np.newaxis, :], + np.array(dist2.interval(prob))[np.newaxis, :], + ], + axis=0, + ) + intervals = normal_sample.azstats.hdi(dims="sample", method=method, prob=prob) + assert_array_almost_equal(intervals, exact_hdis, 1) + + if nearest: + assert np.all(np.isin(intervals, normal_sample)) + else: + assert not np.any(np.isin(intervals, normal_sample)) + + +@pytest.mark.parametrize("prob", [0.56, 0.83]) +def test_hdi_multimodal_discrete(prob): + rng = np.random.default_rng(43) + dist1 = poisson(10) + dist2 = poisson(100) + x = np.concatenate([dist1.rvs(2500000, random_state=rng), dist2.rvs(2500000, random_state=rng)]) + sample = ndarray_to_dataarray(x, "x", sample_dims=["sample"]) + intervals = sample.azstats.hdi(dims="sample", method="multimodal", prob=prob) + assert intervals.sizes["mode"] == 2 + lower_mode = intervals.sel(mode=0) + higher_mode = intervals.sel(mode=1) + assert lower_mode[0] <= 10 <= lower_mode[1] + assert higher_mode[0] <= 100 <= higher_mode[1] + + # restrict the bins to a range in which only a single mode will appear + bins = np.arange(0, 20) - 0.5 + intervals = sample.azstats.hdi(dims="sample", method="multimodal", prob=prob, bins=bins) + assert intervals.sizes["mode"] == 1 + assert intervals.sel(mode=0)[0] <= 10 <= intervals.sel(mode=0)[1] + + +@pytest.mark.parametrize("prob", [0.56, 0.83]) +@pytest.mark.parametrize("dist", [poisson(10), bernoulli(0.7)]) +def test_hdi_multimodal_unimodal_discrete_consistent(dist, prob): + rng = np.random.default_rng(43) + x = dist.rvs(size=1_000, random_state=rng) + sample = ndarray_to_dataarray(x, "x", sample_dims=["sample"]) + intervals = sample.azstats.hdi(dims="sample", method="multimodal", prob=prob) + intervals_unimodal = sample.azstats.hdi(dims="sample", method="nearest", prob=prob) + intervals = intervals.squeeze("mode") + assert_array_equal(intervals, intervals_unimodal) def test_hdi_multimodal_multivars(): @@ -117,6 +170,38 @@ def test_hdi_multimodal_multivars(): assert "var2_mode" in intervals.var2.dims +def test_hdi_multimodal_max_modes(): + rng = np.random.default_rng(42) + x = np.concatenate([rng.normal(0, 1, 250_000), rng.normal(30, 1, 2_500_000)]) + sample = ndarray_to_dataarray(x, "x", sample_dims=["sample"]) + intervals = sample.azstats.hdi(dims="sample", method="multimodal", prob=0.9) + assert intervals.sizes["mode"] == 2 + with pytest.warns(UserWarning, match="found more modes"): + intervals2 = sample.azstats.hdi(dims="sample", method="multimodal", prob=0.9, max_modes=1) + assert intervals2.sizes["mode"] == 1 + assert intervals2.equals(intervals.isel(mode=[1])) + + +@pytest.mark.parametrize("nearest", [True, False]) +def test_hdi_multimodal_circular(nearest): + rng = np.random.default_rng(43) + normal_sample = ndarray_to_dataarray( + np.concatenate( + [ + rng.vonmises(np.pi, 10, 2500000), + rng.vonmises(np.pi / 2, 10, 2500000), + ] + ), + "x", + sample_dims=["sample"], + ) + method = "multimodal_sample" if nearest else "multimodal" + interval = normal_sample.azstats.hdi(circular=True, method=method, prob=0.83, dims="sample") + assert interval.sizes["mode"] == 2 + assert interval.sel(mode=0)[0] <= np.pi / 2 <= interval.sel(mode=0)[1] + assert interval.sel(mode=1)[0] <= np.pi and interval.sel(mode=1)[1] >= -np.pi + + def test_hdi_circular(): rng = np.random.default_rng(43) normal_sample = ndarray_to_dataarray( @@ -124,6 +209,10 @@ def test_hdi_circular(): ) interval = normal_sample.azstats.hdi(circular=True, prob=0.83, dims="sample") assert_array_almost_equal(interval, [1.3, -1.4], 1) + interval_multi = normal_sample.azstats.hdi( + circular=True, prob=0.83, dims="sample", method="multimodal" + ) + assert_array_almost_equal(interval_multi, [[1.3, -1.4]], 1) def test_hdi_bad_ci():