Improvements to multimodal HDI (#28)
* Fix typo

* Refactor multimodal HDI code

More modular functions and vectorization

* Ensure interval contains >=hdi_prob

* For integer/bool HDI, default to bin width of 1

* Split continuous and discrete multimodal HDI

* Default to ISJ bandwidth for multimodal HDI

* Return highest probability modes

* Fix bugs in circular KDE

* Support circular continuous multimodal HDI

* Assume input probabilities sum to 1

* Merge lines

* Scale KDE density to bin probabilities

* Use bins returned by `_histogram`

* Avoid duplication of HDI defaults

* Fix and test passing bins to discrete multimodal

* Simplify HDI nearest code

* Fix circular standardization

* Correctly compute bin centers

* Fix pylint issues

* Move interval splitting to own function

* Use circular standardization

* Add method for computing HDI from point densities

* Add multimodal_nearest HDI method

* Rename and add check for warning in tests

---------

Co-authored-by: Oriol (ProDesk) <[email protected]>
sethaxen and OriolAbril authored Oct 26, 2024
1 parent 5b110df commit c5a4f5f
Showing 4 changed files with 255 additions and 76 deletions.
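
For orientation before the diffs, a minimal usage sketch of the behaviour this commit targets. The `azstats` accessor name and the exact call signature are assumptions for illustration only; the diff itself touches the array and DataArray backends, not the public entry point.

import numpy as np
import xarray as xr

import arviz_stats  # noqa: F401  # assumed to register the DataArray accessor

rng = np.random.default_rng(0)
# clearly bimodal sample: two well separated normal components
samples = np.concatenate([rng.normal(-4.0, 0.5, 2000), rng.normal(4.0, 0.5, 2000)])
da = xr.DataArray(samples.reshape(4, 1000), dims=("chain", "draw"))

# KDE-based multimodal HDI for continuous data; per this commit the bandwidth
# defaults to "isj", or "taylor" when circular=True
intervals = da.azstats.hdi(prob=0.9, method="multimodal")

# new sample-based variant: the KDE density is evaluated at the sample points themselves
intervals_sample = da.azstats.hdi(prob=0.9, method="multimodal_sample")

With a clearly bimodal sample and prob=0.9, both multimodal methods should return two disjoint (lower, higher) intervals, one per mode; internally the result is padded up to max_modes and the all-NaN rows are trimmed before returning.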
34 changes: 25 additions & 9 deletions src/arviz_stats/base/array.py
@@ -47,16 +47,24 @@ def hdi(
         circular=False,
         max_modes=10,
         skipna=False,
+        **kwargs,
     ):
-        """Compute of HDI function on array-like input."""
+        """Compute HDI function on array-like input."""
         if not 1 >= prob > 0:
             raise ValueError("The value of `prob` must be in the (0, 1] interval.")
-        if method == "multimodal" and circular:
-            raise ValueError("Multimodal hdi not supported for circular data.")
         ary, axes = process_ary_axes(ary, axes)
+        is_discrete = np.issubdtype(ary.dtype, np.integer) or np.issubdtype(ary.dtype, np.bool_)
+        is_multimodal = method.startswith("multimodal")
+        if is_multimodal and circular and is_discrete:
+            raise ValueError("Multimodal hdi not supported for discrete circular data.")
         hdi_func = {
             "nearest": self._hdi_nearest,
-            "multimodal": self._hdi_multimodal,
+            "multimodal": (
+                self._hdi_multimodal_discrete if is_discrete else self._hdi_multimodal_continuous
+            ),
+            "multimodal_sample": (
+                self._hdi_multimodal_discrete if is_discrete else self._hdi_multimodal_continuous
+            ),
         }[method]
         hdi_array = make_ufunc(
             hdi_func,
@@ -67,15 +75,23 @@ def hdi(
         func_kwargs = {
             "prob": prob,
             "skipna": skipna,
-            "out_shape": (max_modes, 2) if method == "multimodal" else (2,),
+            "out_shape": (max_modes, 2) if is_multimodal else (2,),
+            "circular": circular,
         }
-        if method != "multimodal":
-            func_kwargs["circular"] = circular
-        else:
+        if is_multimodal:
             func_kwargs["max_modes"] = max_modes
+            if is_discrete:
+                func_kwargs.pop("circular")
+                func_kwargs.pop("skipna")
+            else:
+                func_kwargs["bw"] = "isj" if not circular else "taylor"
+        func_kwargs.update(kwargs)
+
+        if method == "multimodal_sample":
+            func_kwargs["from_sample"] = True
 
         result = hdi_array(ary, **func_kwargs)
-        if method == "multimodal":
+        if is_multimodal:
             mode_mask = [np.all(np.isnan(result[..., i, :])) for i in range(result.shape[-2])]
             result = result[..., ~np.array(mode_mask), :]
         return result
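
To summarize the new dispatch, a standalone restatement of the kwarg handling above. This is an illustrative sketch that mirrors the diff rather than extending it; the helper name `_hdi_kwargs_sketch` is hypothetical.

import numpy as np


def _hdi_kwargs_sketch(ary, method, prob, circular=False, skipna=False, max_modes=10):
    """Mirror the kwarg selection done by `hdi` above (sketch, not the library code)."""
    is_discrete = np.issubdtype(ary.dtype, np.integer) or np.issubdtype(ary.dtype, np.bool_)
    is_multimodal = method.startswith("multimodal")
    func_kwargs = {
        "prob": prob,
        "skipna": skipna,
        "out_shape": (max_modes, 2) if is_multimodal else (2,),
        "circular": circular,
    }
    if is_multimodal:
        func_kwargs["max_modes"] = max_modes
        if is_discrete:
            # discrete data: bins are the unique values, so circular/skipna do not apply
            func_kwargs.pop("circular")
            func_kwargs.pop("skipna")
        else:
            # continuous data: ISJ bandwidth by default, Taylor bandwidth for circular KDE
            func_kwargs["bw"] = "isj" if not circular else "taylor"
    if method == "multimodal_sample":
        func_kwargs["from_sample"] = True
    return func_kwargs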
188 changes: 133 additions & 55 deletions src/arviz_stats/base/core.py
@@ -86,6 +86,10 @@ def circular_mean(self, ary):  # pylint: disable=no-self-use
         """
         return circmean(ary, high=np.pi, low=-np.pi)
 
+    def _circular_standardize(self, ary):  # pylint: disable=no-self-use
+        """Standardize circular data to the interval [-pi, pi]."""
+        return np.mod(ary + np.pi, 2 * np.pi) - np.pi
+
     def quantile(self, ary, quantile, **kwargs):  # pylint: disable=no-self-use
         """Compute the quantile of an array of samples.
@@ -226,20 +230,9 @@ def _histogram(self, ary, bins=None, range=None, weights=None, density=None):
             bins = self._get_bins(ary)
         return np.histogram(ary, bins=bins, range=range, weights=weights, density=density)
 
-    def _hdi_linear_nearest_common(self, ary, prob, skipna, circular):
-        ary = ary.flatten()
-        if skipna:
-            nans = np.isnan(ary)
-            if not nans.all():
-                ary = ary[~nans]
+    def _hdi_linear_nearest_common(self, ary, prob):  # pylint: disable=no-self-use
         n = len(ary)
 
-        mean = None
-        if circular:
-            mean = self.circular_mean(ary)
-            ary = ary - mean
-            ary = np.arctan2(np.sin(ary), np.cos(ary))
-
         ary = np.sort(ary)
         interval_idx_inc = int(np.floor(prob * n))
         n_intervals = n - interval_idx_inc
@@ -249,62 +242,147 @@ def _hdi_linear_nearest_common(self, ary, prob, skipna, circular):
             raise ValueError("Too few elements for interval calculation. ")
 
         min_idx = np.argmin(interval_width)
+        hdi_interval = ary[[min_idx, min_idx + interval_idx_inc]]
 
-        return ary, mean, min_idx, interval_idx_inc
+        return hdi_interval
 
     def _hdi_nearest(self, ary, prob, circular, skipna):
         """Compute HDI over the flattened array as closest samples that contain the given prob."""
-        ary, mean, min_idx, interval_idx_inc = self._hdi_linear_nearest_common(
-            ary, prob, skipna, circular
-        )
-
-        hdi_min = ary[min_idx]
-        hdi_max = ary[min_idx + interval_idx_inc]
+        ary = ary.flatten()
+        if skipna:
+            nans = np.isnan(ary)
+            if not nans.all():
+                ary = ary[~nans]
 
         if circular:
-            hdi_min = hdi_min + mean
-            hdi_max = hdi_max + mean
-            hdi_min = np.arctan2(np.sin(hdi_min), np.cos(hdi_min))
-            hdi_max = np.arctan2(np.sin(hdi_max), np.cos(hdi_max))
+            mean = self.circular_mean(ary)
+            ary = self._circular_standardize(ary - mean)
+
+        hdi_interval = self._hdi_linear_nearest_common(ary, prob)
 
-        hdi_interval = np.array([hdi_min, hdi_max])
+        if circular:
+            hdi_interval = self._circular_standardize(hdi_interval + mean)
 
         return hdi_interval
 
-    def _hdi_multimodal(self, ary, prob, skipna, max_modes):
+    def _hdi_multimodal_continuous(
+        self, ary, prob, skipna, max_modes, circular, from_sample=False, **kwargs
+    ):
         """Compute HDI if the distribution is multimodal."""
         ary = ary.flatten()
         if skipna:
             ary = ary[~np.isnan(ary)]
 
-        if ary.dtype.kind == "f":
-            bins, density, _ = self.kde(ary)
-            lower, upper = bins[0], bins[-1]
-            range_x = upper - lower
-            dx = range_x / len(density)
+        bins, density, _ = self.kde(ary, circular=circular, **kwargs)
+        if from_sample:
+            ary_density = np.interp(ary, bins, density)
+            hdi_intervals, interval_probs = self._hdi_from_point_densities(
+                ary, ary_density, prob, circular
+            )
         else:
-            bins = self._get_bins(ary)
-            density, _ = self._histogram(ary, bins=bins, density=True)
-            dx = np.diff(bins)[0]
-
-        density *= dx
-
-        idx = np.argsort(-density)
-        intervals = bins[idx][density[idx].cumsum() <= prob]
-        intervals.sort()
-
-        intervals_splitted = np.split(intervals, np.where(np.diff(intervals) >= dx * 1.1)[0] + 1)
-
-        hdi_intervals = np.full((max_modes, 2), np.nan)
-        for i, interval in enumerate(intervals_splitted):
-            if i == max_modes:
-                warnings.warn(
-                    f"found more modes than {max_modes}, returning only the first {max_modes} modes"
-                )
-                break
-            if interval.size == 0:
-                hdi_intervals[i] = np.asarray([bins[0], bins[0]])
-            else:
-                hdi_intervals[i] = np.asarray([interval[0], interval[-1]])
-
-        return np.array(hdi_intervals)
+            dx = (bins[-1] - bins[0]) / (len(bins) - 1)
+            bin_probs = density * dx
+
+            hdi_intervals, interval_probs = self._hdi_from_bin_probabilities(
+                bins, bin_probs, prob, circular, dx
+            )
+
+        return self._pad_hdi_to_maxmodes(hdi_intervals, interval_probs, max_modes)
+
+    def _hdi_multimodal_discrete(self, ary, prob, max_modes, bins=None):
+        """Compute HDI if the distribution is multimodal."""
+        ary = ary.flatten()
+
+        if bins is None:
+            bins, counts = np.unique(ary, return_counts=True)
+            bin_probs = counts / len(ary)
+            dx = 1
+        else:
+            counts, edges = self._histogram(ary, bins=bins)
+            bins = 0.5 * (edges[1:] + edges[:-1])
+            bin_probs = counts / counts.sum()
+            dx = bins[1] - bins[0]
+
+        hdi_intervals, interval_probs = self._hdi_from_bin_probabilities(
+            bins, bin_probs, prob, False, dx
+        )
+
+        return self._pad_hdi_to_maxmodes(hdi_intervals, interval_probs, max_modes)
+
+    def _hdi_from_point_densities(self, points, densities, prob, circular):
+        if circular:
+            points = self._circular_standardize(points)
+
+        sorted_idx = np.argsort(points)
+        points = points[sorted_idx]
+        densities = densities[sorted_idx]
+
+        # find idx of points in the interval
+        interval_size = int(np.ceil(prob * len(points)))
+        sorted_idx = np.argsort(densities)[::-1]
+        idx_in_interval = sorted_idx[:interval_size]
+        idx_in_interval.sort()
+
+        # find idx of interval bounds
+        probs_in_interval = np.full(idx_in_interval.shape, 1 / len(points))
+        interval_bounds_idx, interval_probs = self._interval_points_to_bounds(
+            idx_in_interval, probs_in_interval, 1, circular, period=len(points)
+        )
+
+        return points[interval_bounds_idx], interval_probs
+
+    def _hdi_from_bin_probabilities(self, bins, bin_probs, prob, circular, dx):
+        if circular:
+            bins = self._circular_standardize(bins)
+            sorted_idx = np.argsort(bins)
+            bins = bins[sorted_idx]
+            bin_probs = bin_probs[sorted_idx]
+
+        # find idx of bins in the interval
+        sorted_idx = np.argsort(bin_probs)[::-1]
+        cum_probs = bin_probs[sorted_idx].cumsum()
+        interval_size = np.searchsorted(cum_probs, prob, side="left") + 1
+        idx_in_interval = sorted_idx[:interval_size]
+        idx_in_interval.sort()
+
+        # get points in intervals
+        intervals = bins[idx_in_interval]
+        probs_in_interval = bin_probs[idx_in_interval]
+
+        return self._interval_points_to_bounds(intervals, probs_in_interval, dx, circular)
+
+    def _interval_points_to_bounds(self, points, probs, dx, circular, period=2 * np.pi):  # pylint: disable=no-self-use
+        cum_probs = probs.cumsum()
+
+        is_bound = np.diff(points) > dx * 1.01
+        is_lower_bound = np.insert(is_bound, 0, True)
+        is_upper_bound = np.append(is_bound, True)
+        interval_bounds = np.column_stack([points[is_lower_bound], points[is_upper_bound]])
+        interval_probs = (
+            cum_probs[is_upper_bound] - cum_probs[is_lower_bound] + probs[is_lower_bound]
+        )
+
+        if (
+            circular
+            and np.mod(dx * 1.01 + interval_bounds[-1, -1] - interval_bounds[0, 0], period)
+            <= dx * 1.01
+        ):
+            interval_bounds[-1, 1] = interval_bounds[0, 1]
+            interval_bounds = interval_bounds[1:, :]
+            interval_probs[-1] += interval_probs[0]
+            interval_probs = interval_probs[1:]
+
+        return interval_bounds, interval_probs
+
+    def _pad_hdi_to_maxmodes(self, hdi_intervals, interval_probs, max_modes):  # pylint: disable=no-self-use
+        if hdi_intervals.shape[0] > max_modes:
+            warnings.warn(
+                f"found more modes than {max_modes}, returning only the {max_modes} highest "
+                "probability modes"
+            )
+            hdi_intervals = hdi_intervals[np.argsort(interval_probs)[::-1][:max_modes], :]
+        elif hdi_intervals.shape[0] < max_modes:
+            hdi_intervals = np.vstack(
+                [hdi_intervals, np.full((max_modes - hdi_intervals.shape[0], 2), np.nan)]
+            )
+        return hdi_intervals
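
Every multimodal variant above funnels into the same core step implemented by `_hdi_from_bin_probabilities` and `_interval_points_to_bounds`: rank the bins (or points) by probability, keep the smallest set whose cumulative mass reaches `prob`, and split the kept bins into contiguous runs, one per mode. A self-contained NumPy sketch of that step, ignoring the circular wrap-around merge; the function name `multimodal_hdi_sketch` is hypothetical and this is illustration only, not the library code.

import numpy as np


def multimodal_hdi_sketch(bins, bin_probs, prob):
    """Toy bin-probability HDI: `bins` must be equally spaced bin centers."""
    dx = bins[1] - bins[0]
    order = np.argsort(bin_probs)[::-1]                 # highest-probability bins first
    cum = bin_probs[order].cumsum()
    n_in = np.searchsorted(cum, prob, side="left") + 1  # smallest set with mass >= prob
    in_hdi = np.sort(order[:n_in])
    # split wherever two kept bins are not adjacent
    gaps = np.where(np.diff(bins[in_hdi]) > dx * 1.01)[0] + 1
    return [(run[0], run[-1]) for run in np.split(bins[in_hdi], gaps)]


rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-4.0, 0.5, 5000), rng.normal(4.0, 0.5, 5000)])
counts, edges = np.histogram(x, bins=100)
centers = 0.5 * (edges[:-1] + edges[1:])
print(multimodal_hdi_sketch(centers, counts / counts.sum(), 0.9))
# typically two (lower, upper) pairs, one around each mode

On this bimodal sample the kept bins form two runs, so the 90% HDI comes back as two disjoint intervals; the library code additionally rescales KDE densities to bin probabilities, handles circular data, and pads or truncates the result to max_modes.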
12 changes: 4 additions & 8 deletions src/arviz_stats/base/dataarray.py
@@ -34,9 +34,7 @@ def eti(self, da, prob=None, dims=None, method="linear"):
             kwargs={"axis": np.arange(-len(dims), 0, 1), "method": method},
         )
 
-    def hdi(
-        self, da, prob=None, dims=None, method="nearest", circular=False, max_modes=10, skipna=False
-    ):
+    def hdi(self, da, prob=None, dims=None, method="nearest", **kwargs):
         """Compute hdi on DataArray input."""
         dims = validate_dims(dims)
         prob = validate_ci_prob(prob)
@@ -48,13 +46,11 @@ def hdi(
             da,
             prob,
             input_core_dims=[dims, []],
-            output_core_dims=[[mode_dim, "hdi"] if method == "multimodal" else ["hdi"]],
+            output_core_dims=[[mode_dim, "hdi"] if method.startswith("multimodal") else ["hdi"]],
             kwargs={
-                "method": method,
-                "circular": circular,
-                "skipna": skipna,
-                "max_modes": max_modes,
                 "axes": np.arange(-len(dims), 0, 1),
+                "method": method,
+                **kwargs,
             },
         ).assign_coords({"hdi": hdi_coord})
 