Skip to content

Commit

Permalink
work on accessor return types (#19)
Browse files Browse the repository at this point in the history
  • Loading branch information
OriolAbril authored Aug 13, 2024
1 parent 1fb6c7d commit ad46a84
Showing 1 changed file with 19 additions and 13 deletions.
32 changes: 19 additions & 13 deletions src/arviz_stats/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ def filter_vars(self, var_names=None, filter_vars=None):

def _apply(self, fun, dims, **kwargs):
"""Apply a function to all variables subsetting dims to existing dimensions."""
if isinstance(fun, str):
fun = get_function(fun)
return xr.Dataset(
{
var_name: fun(da, dims=update_dims(dims, da), **kwargs)
Expand All @@ -105,34 +107,32 @@ def _apply(self, fun, dims, **kwargs):
def eti(self, prob=None, dims=None, **kwargs):
"""Compute the equal tail interval of all the variables in the dataset."""
kwargs["prob"] = prob
return self._apply(get_function("eti"), dims=dims, **kwargs)
return self._apply("eti", dims=dims, **kwargs)

def hdi(self, prob=None, dims=None, **kwargs):
"""Compute hdi on all variables in the dataset."""
kwargs["prob"] = prob
return self._apply(get_function("hdi"), dims=dims, **kwargs)
return self._apply("hdi", dims=dims, **kwargs)

def ess(self, dims=None, method="bulk", relative=False, prob=None):
"""Compute the ess of all the variables in the dataset."""
return self._apply(
get_function("ess"), dims=dims, method=method, relative=relative, prob=prob
)
return self._apply("ess", dims=dims, method=method, relative=relative, prob=prob)

def rhat(self, dims=None, method="rank"):
"""Compute the rhat of all the variables in the dataset."""
return self._apply(get_function("rhat"), dims=dims, method=method)
return self._apply("rhat", dims=dims, method=method)

def mcse(self, dims=None, method="mean", prob=None):
"""Compute the mcse of all the variables in the dataset."""
return self._apply(get_function("mcse"), dims=dims, method=method, prob=prob)
return self._apply("mcse", dims=dims, method=method, prob=prob)

def kde(self, dims=None, **kwargs):
"""Compute the KDE for all variables in the dataset."""
return self._apply(get_function("kde"), dims=dims, **kwargs)
return self._apply("kde", dims=dims, **kwargs)

def histogram(self, dims=None, **kwargs):
"""Compute the KDE for all variables in the dataset."""
return self._apply(get_function("histogram"), dims=dims, **kwargs)
return self._apply("histogram", dims=dims, **kwargs)

def ecdf(self, dims=None, **kwargs):
"""Compute the ecdf for all variables in the dataset."""
Expand Down Expand Up @@ -160,12 +160,18 @@ def _process_input(self, group, method):
return self._obj

def _apply(self, fun_name, dims, group, **kwargs):
if isinstance(group, str):
group = [group]
return DataTree.from_dict(
{
var_name: get_function(fun_name)(da, dims=update_dims(dims, da), **kwargs)
for var_name, da in self._process_input(group, fun_name).items()
},
name=group,
group_i: xr.Dataset(
{
var_name: get_function(fun_name)(da, dims=update_dims(dims, da), **kwargs)
for var_name, da in self._process_input(group_i, fun_name).items()
}
)
for group_i in group
}
)

def filter_vars(self, group="posterior", var_names=None, filter_vars=None):
Expand Down

0 comments on commit ad46a84

Please sign in to comment.