Skip to content

Commit

Permalink
add compute_ranks function (#17)
Browse files Browse the repository at this point in the history
* add compute_ranks function

* add xr compatible version

* add relative option

* some fixes

* lint fixes

* don't log warning for arrays full of nans

* Apply suggestions from code review

Co-authored-by: Osvaldo A Martin <[email protected]>

* remove unnecessary list comprehension

* update docs

---------

Co-authored-by: Osvaldo A Martin <[email protected]>
  • Loading branch information
OriolAbril and aloctavodia authored Aug 30, 2024
1 parent ad46a84 commit d3f313b
Show file tree
Hide file tree
Showing 7 changed files with 83 additions and 86 deletions.
2 changes: 2 additions & 0 deletions docs/source/api/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@ Currently, using accessors is the recommended way to call functions from `arviz_
xarray.Dataset.azstats.filter_vars
xarray.Dataset.azstats.eti
xarray.Dataset.azstats.hdi
xarray.Dataset.azstats.compute_ranks
xarray.Dataset.azstats.ess
xarray.Dataset.azstats.rhat
xarray.Dataset.azstats.mcse
xarray.Dataset.azstats.thin
xarray.Dataset.azstats.kde
xarray.Dataset.azstats.histogram
xarray.Dataset.azstats.ecdf
Expand Down
4 changes: 4 additions & 0 deletions src/arviz_stats/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,10 @@ def histogram(self, dims=None, **kwargs):
"""Compute the histogram for all variables in the dataset."""
return self._apply("histogram", dims=dims, **kwargs)

def compute_ranks(self, dims=None, relative=False):
    """Compute ranks for all variables in the dataset.

    Parameters
    ----------
    dims : optional
        Forwarded unchanged to ``self._apply`` as ``dims``.
    relative : bool, default False
        Forwarded unchanged to ``self._apply`` as ``relative``.

    Returns
    -------
    Whatever ``self._apply`` returns for the ``"compute_ranks"`` function.
    """
    options = {"dims": dims, "relative": relative}
    return self._apply("compute_ranks", **options)

def ecdf(self, dims=None, **kwargs):
"""Compute the ecdf for all variables in the dataset."""
# TODO: implement ecdf here so it doesn't depend on numba
Expand Down
14 changes: 14 additions & 0 deletions src/arviz_stats/base/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def process_ary_axes(ary, axes):
ary : array_like
axes : int or sequence of int
"""
if axes is None:
axes = list(range(ary.ndim))
if isinstance(axes, int):
axes = [axes]
axes = [ax if ax >= 0 else ary.ndim + ax for ax in axes]
Expand Down Expand Up @@ -130,6 +132,18 @@ def mcse(self, ary, chain_axis=-2, draw_axis=-1, method="mean", prob=None):
mcse_array = make_ufunc(mcse_func, n_output=1, n_input=1, n_dims=2, ravel=False)
return mcse_array(ary, **func_kwargs)

def compute_ranks(self, ary, axes=-1, relative=False):
    """Compute ranks of MCMC samples.

    Parameters
    ----------
    ary : array_like
        Values to rank.
    axes : int or sequence of int, default -1
        Axes handed to ``process_ary_axes``; their number sets ``n_dims`` of the
        generated ufunc and their sizes set the shape of each output element.
    relative : bool, default False
        Forwarded to ``self._compute_ranks``.

    Returns
    -------
    The output of the ufunc built by ``make_ufunc`` around ``self._compute_ranks``.
    """
    ary, axes = process_ary_axes(ary, axes)
    compute_ranks_ufunc = make_ufunc(
        self._compute_ranks,
        n_output=1,
        n_input=1,
        n_dims=len(axes),
        ravel=False,
    )
    # Materialize the shape as a tuple: a generator can be consumed only once and
    # has no len(), so it is not a safe stand-in for a shape.
    out_shape = tuple(ary.shape[i] for i in axes)
    return compute_ranks_ufunc(ary, out_shape=out_shape, relative=relative)

def get_bins(self, ary, axes=-1):
"""Compute default bins."""
ary, axes = process_ary_axes(ary, axes)
Expand Down
29 changes: 29 additions & 0 deletions src/arviz_stats/base/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import numpy as np
from scipy.fftpack import next_fast_len
from scipy.interpolate import CubicSpline
from scipy.stats import circmean


Expand Down Expand Up @@ -108,6 +109,34 @@ def eti(self, ary, prob, **kwargs):
edge_prob = (1 - prob) / 2
return self.quantile(ary, [edge_prob, 1 - edge_prob], **kwargs)

def _float_rankdata(self, ary):  # pylint: disable=no-self-use
    """Compute ranks on continuous data, assuming there are no ties.

    Parameters
    ----------
    ary : numpy.ndarray
        Values to rank. An array with more than one dimension is ranked as a
        whole, in flattened order.

    Returns
    -------
    numpy.ndarray
        1D integer array with ``ary.size`` elements; ranks start at 1.

    Notes
    -----
    :func:`scipy.stats.rankdata` is focused on discrete data and different ways
    to resolve ties. However, our most common use is converting all data to continuous
    to get rid of the ties and then calling rankdata, which is neither very efficient
    nor numba compatible.
    """
    n_values = ary.size
    # argsort(axis=None) indexes the flattened array, so the buffer must hold
    # ary.size entries (len(ary) would only cover the first dimension).
    ranks = np.empty(n_values, dtype=int)
    ranks[np.argsort(ary, axis=None)] = np.arange(1, n_values + 1)
    return ranks

def _compute_ranks(self, ary, relative=False):
    """Compute ranks for continuous and discrete variables.

    Parameters
    ----------
    ary : numpy.ndarray
        Values to rank; all elements are ranked together regardless of shape.
    relative : bool, default False
        If True, divide the ranks by the number of elements.

    Returns
    -------
    numpy.ndarray
        Array with the same shape as `ary` holding the ranks (starting at 1), or
        the ranks divided by ``ary.size`` when `relative` is True.
    """
    ary_shape = ary.shape
    flat = ary.flatten()
    if flat.dtype.kind == "i" and flat.size > 1:
        min_ary, max_ary = flat.min(), flat.max()
        # CubicSpline needs a strictly increasing grid; a constant array would give
        # a degenerate one, so it is ranked as is.
        if min_ary < max_ary:
            x = np.linspace(min_ary, max_ary, flat.size)
            csi = CubicSpline(x, flat)
            # Evaluate slightly inside the range to turn the integers into floats.
            flat = csi(np.linspace(min_ary + 0.001, max_ary - 0.001, flat.size))
    # Rank the flattened values and only then restore the original shape.
    out = self._float_rankdata(flat).reshape(ary_shape)
    if relative:
        return out / out.size
    return out

def _get_bininfo(self, values):
dtype = values.dtype.kind

Expand Down
14 changes: 14 additions & 0 deletions src/arviz_stats/base/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,20 @@ def ess(self, da, dims=None, method="bulk", relative=False, prob=None):
kwargs={"method": method, "relative": relative, "prob": prob},
)

def compute_ranks(self, da, dims=None, relative=False):
    """Compute ranks on DataArray input.

    Parameters
    ----------
    da : DataArray
        Input passed to :func:`xarray.apply_ufunc`.
    dims : str or sequence of str, optional
        Dimensions to rank over. Defaults to ``rcParams["data.sample_dims"]``;
        a single string is wrapped in a list.
    relative : bool, default False
        Forwarded to the array-level ``compute_ranks``.

    Returns
    -------
    The output of :func:`xarray.apply_ufunc`, with `dims` kept as output core dims.
    """
    if dims is None:
        dims = rcParams["data.sample_dims"]
    if isinstance(dims, str):
        dims = [dims]
    # apply_ufunc moves the core dims to the last axes. Point the array function at
    # all of them so values are ranked jointly over `dims`; otherwise its default
    # ``axes=-1`` would rank along the last dim only.
    core_axes = list(range(-len(dims), 0))
    return apply_ufunc(
        self.array_class.compute_ranks,
        da,
        input_core_dims=[dims],
        output_core_dims=[dims],
        kwargs={"relative": relative, "axes": core_axes},
    )

def rhat(self, da, dims=None, method="bulk"):
"""Compute rhat on DataArray input."""
if dims is None:
Expand Down
62 changes: 7 additions & 55 deletions src/arviz_stats/base/stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@

import numpy as np
from scipy.interpolate import CubicSpline
from xarray import apply_ufunc

__all__ = ["make_ufunc", "wrap_xarray_ufunc"]
__all__ = ["make_ufunc"]

_log = logging.getLogger(__name__)

Expand Down Expand Up @@ -125,58 +124,6 @@ def _multi_ufunc(*args, out=None, out_shape=None, shape_from_1st=False, **kwargs
return ufunc


def wrap_xarray_ufunc(
    ufunc,
    *datasets,
    ufunc_kwargs=None,
    func_args=None,
    func_kwargs=None,
    **kwargs,
):
    """Wrap make_ufunc with xarray.apply_ufunc.

    Parameters
    ----------
    ufunc : callable
    *datasets : xarray.Dataset
    ufunc_kwargs : dict
        Keyword arguments passed to `make_ufunc`. The dict is copied, never modified.
            - 'n_dims', int, by default the number of core dims of the last input
              (2 with the default ``("chain", "draw")`` core dims)
            - 'n_output', int, by default 1
            - 'n_input', int, by default len(datasets)
            - 'index', slice, by default Ellipsis
            - 'ravel', bool, by default True
    func_args : tuple
        Arguments passed to 'ufunc'.
    func_kwargs : dict
        Keyword arguments passed to 'ufunc'.
            - 'out_shape', int, by default None
    **kwargs
        Passed to :func:`xarray.apply_ufunc`.

    Returns
    -------
    xarray.Dataset
    """
    # Work on a copy so the defaults filled in below never leak into the caller's dict.
    ufunc_kwargs = {} if ufunc_kwargs is None else dict(ufunc_kwargs)
    ufunc_kwargs.setdefault("n_input", len(datasets))
    if func_args is None:
        func_args = tuple()
    if func_kwargs is None:
        func_kwargs = {}

    kwargs.setdefault(
        "input_core_dims", tuple(("chain", "draw") for _ in range(len(func_args) + len(datasets)))
    )
    ufunc_kwargs.setdefault("n_dims", len(kwargs["input_core_dims"][-1]))
    kwargs.setdefault("output_core_dims", tuple([] for _ in range(ufunc_kwargs.get("n_output", 1))))

    callable_ufunc = make_ufunc(ufunc, **ufunc_kwargs)

    return apply_ufunc(callable_ufunc, *datasets, *func_args, kwargs=func_kwargs, **kwargs)


def update_docstring(ufunc, func, n_output=1):
"""Update ArviZ generated ufunc docstring."""
module = ""
Expand Down Expand Up @@ -305,11 +252,16 @@ def not_valid(ary, check_nan=True, check_shape=True, nan_kwargs=None, shape_kwar
draw_error = False
chain_error = False

# for arviz-plots alignment, if all elements are nan return nan without indicating
# any error
isnan = np.isnan(ary)
if isnan.all():
return True

if check_nan:
if nan_kwargs is None:
nan_kwargs = {}

isnan = np.isnan(ary)
axis = nan_kwargs.get("axis", None)
if nan_kwargs.get("how", "any").lower() == "all":
nan_error = isnan.all(axis)
Expand Down
44 changes: 13 additions & 31 deletions tests/base/test_stats_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as np
import pytest
from arviz_stats.base.stats_utils import logsumexp as _logsumexp
from arviz_stats.base.stats_utils import make_ufunc, not_valid, wrap_xarray_ufunc
from arviz_stats.base.stats_utils import make_ufunc, not_valid
from numpy.testing import assert_array_almost_equal
from scipy.special import logsumexp

Expand Down Expand Up @@ -68,19 +68,15 @@ def test_logsumexp_b_inv(ary_dtype, axis, b_inv, keepdims):

@pytest.mark.parametrize("quantile", ((0.5,), (0.5, 0.1)))
@pytest.mark.parametrize("arg", (True, False))
def test_wrap_ufunc_output(quantile, arg):
def test_make_ufunc_output(quantile, arg):
ary = np.random.randn(4, 100)
n_output = len(quantile)
if arg:
res = wrap_xarray_ufunc(
np.quantile, ary, ufunc_kwargs={"n_output": n_output}, func_args=(quantile,)
)
res = make_ufunc(np.quantile, n_output=n_output)(ary, quantile)
elif n_output == 1:
res = wrap_xarray_ufunc(np.quantile, ary, func_kwargs={"q": quantile})
res = make_ufunc(np.quantile)(ary, q=quantile)
else:
res = wrap_xarray_ufunc(
np.quantile, ary, ufunc_kwargs={"n_output": n_output}, func_kwargs={"q": quantile}
)
res = make_ufunc(np.quantile, n_output=n_output)(ary, q=quantile)
if n_output == 1:
assert not isinstance(res, tuple)
else:
Expand All @@ -90,48 +86,34 @@ def test_wrap_ufunc_output(quantile, arg):

@pytest.mark.parametrize("out_shape", ((1, 2), (1, 2, 3), (2, 3, 4, 5)))
@pytest.mark.parametrize("input_dim", ((4, 100), (4, 100, 3), (4, 100, 4, 5)))
def test_wrap_ufunc_out_shape(out_shape, input_dim):
def test_make_ufunc_out_shape(out_shape, input_dim):
func = lambda x: np.random.rand(*out_shape)
ary = np.ones(input_dim)
res = wrap_xarray_ufunc(
func, ary, func_kwargs={"out_shape": out_shape}, ufunc_kwargs={"n_dims": 1}
)
res = make_ufunc(func, n_dims=1)(ary, out_shape=out_shape)
assert res.shape == (*ary.shape[:-1], *out_shape)


def test_wrap_ufunc_out_shape_multi_input():
def test_make_ufunc_out_shape_multi_input():
out_shape = (2, 4)
func = lambda x, y: np.random.rand(*out_shape)
ary1 = np.ones((4, 100))
ary2 = np.ones((4, 5))
res = wrap_xarray_ufunc(
func, ary1, ary2, func_kwargs={"out_shape": out_shape}, ufunc_kwargs={"n_dims": 1}
)
res = make_ufunc(func, n_dims=1)(ary1, ary2, out_shape=out_shape)
assert res.shape == (*ary1.shape[:-1], *out_shape)


def test_wrap_ufunc_out_shape_multi_output_same():
def test_make_ufunc_out_shape_multi_output_same():
func = lambda x: (np.random.rand(1, 2), np.random.rand(1, 2))
ary = np.ones((4, 100))
res1, res2 = wrap_xarray_ufunc(
func,
ary,
func_kwargs={"out_shape": ((1, 2), (1, 2))},
ufunc_kwargs={"n_dims": 1, "n_output": 2},
)
res1, res2 = make_ufunc(func, n_dims=1, n_output=2)(ary, out_shape=((1, 2), (1, 2)))
assert res1.shape == (*ary.shape[:-1], 1, 2)
assert res2.shape == (*ary.shape[:-1], 1, 2)


def test_wrap_ufunc_out_shape_multi_output_diff():
def test_make_ufunc_out_shape_multi_output_diff():
func = lambda x: (np.random.rand(5, 3), np.random.rand(10, 4))
ary = np.ones((4, 100))
res1, res2 = wrap_xarray_ufunc(
func,
ary,
func_kwargs={"out_shape": ((5, 3), (10, 4))},
ufunc_kwargs={"n_dims": 1, "n_output": 2},
)
res1, res2 = make_ufunc(func, n_dims=1, n_output=2)(ary, out_shape=((5, 3), (10, 4)))
assert res1.shape == (*ary.shape[:-1], 5, 3)
assert res2.shape == (*ary.shape[:-1], 10, 4)

Expand Down

0 comments on commit d3f313b

Please sign in to comment.