Skip to content

Commit

Permalink
Some maintenance and fixes (#2)
Browse files Browse the repository at this point in the history
* always return arrays of grid_len

* import gaussian from scipy.signal.windows

* update dependencies

* fix kwargs in kde

* update accessors to make them more robust

* dry on accessor code
  • Loading branch information
OriolAbril authored Mar 26, 2024
1 parent 8e40361 commit a40917b
Show file tree
Hide file tree
Showing 5 changed files with 52 additions and 17 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11"]
python-version: ["3.10", "3.11", "3.12"]
fail-fast: false
steps:
- uses: actions/checkout@v3
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ build-backend = "flit_core.buildapi"
[project]
name = "arviz-stats"
readme = "README.md"
requires-python = ">=3.9"
requires-python = ">=3.10"
license = {file = "LICENSE"}
authors = [
{name = "ArviZ team", email = "[email protected]"}
Expand All @@ -19,14 +19,14 @@ classifiers = [
"Operating System :: OS Independent",
"Programming Language :: Python",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
]
dynamic = ["version", "description"]
dependencies = [
"numpy>=1.20",
"xarray>=0.18.0",
"numpy>=1.23",
"xarray>=2022.6.0",
"xarray-datatree",
"arviz-base @ git+https://github.com/arviz-devs/arviz-base",
"xarray-einstats",
Expand Down
44 changes: 39 additions & 5 deletions src/arviz_stats/accessors.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,15 @@ class UnsetDefault:
pass


def update_dims(dims, da):
"""Update dims to contain only those present in da."""
if dims is None:
return None
if isinstance(dims, str):
dims = [dims]
return [dim for dim in dims if dim in da.dims]


unset = UnsetDefault()


Expand All @@ -37,10 +46,23 @@ def hdi(self, prob=None, dims=None, **kwargs):
"""Compute the highest density interval on the DataArray."""
return get_function("hdi")(self._obj, prob=prob, dims=dims, **kwargs)

def kde(self, dims=None, **kwargs):
"""Compute the KDE on the DataArray."""
return get_function("kde")(self._obj, dims=dims, **kwargs)


@xr.register_dataset_accessor("azstats")
class AzStatsDsAccessor(_BaseAccessor):
"""ArviZ stats accessor class for Datasets."""
"""ArviZ stats accessor class for Datasets.
Notes
-----
Whenever "dims" indicates a set of dimensions that are to be reduced, the behaviour
should be to reduce all present dimensions and ignore the ones not present.
Thus, they can't use :meth:`.Dataset.map` and instead we must manually loop over variables
in the dataset, remove elements from dims if necessary and afterwards rebuild the output
Dataset.
"""

@property
def ds(self):
Expand All @@ -66,21 +88,33 @@ def filter_vars(self, var_names=None, filter_vars=None):
self._obj = self._obj[var_names]
return self

def _apply(self, fun, dims, **kwargs):
"""Apply a function to all variables subsetting dims to existing dimensions."""
return xr.Dataset(
{
var_name: fun(da, dims=update_dims(dims, da), **kwargs)
for var_name, da in self._obj.items()
}
)

def eti(self, prob=None, dims=None, **kwargs):
"""Compute the equal tail interval of all the variables in the dataset."""
return self._obj.map(get_function("eti"), prob=prob, dims=dims, **kwargs)
kwargs["prob"] = prob
return self._apply(get_function("eti"), dims=dims, **kwargs)

def hdi(self, prob=None, dims=None, **kwargs):
"""Compute hdi on all variables in the dataset."""
return self._obj.map(get_function("hdi"), prob=prob, dims=dims, **kwargs)
kwargs["prob"] = prob
return self._apply(get_function("hdi"), dims=dims, **kwargs)

def kde(self, dims=None, **kwargs):
"""Compute the KDE for all variables in the dataset."""
return self._obj.map(get_function("kde"), dims=dims, **kwargs)
return self._apply(get_function("kde"), dims=dims, **kwargs)

def ecdf(self, dims=None, **kwargs):
"""Compute the ecdf for all variables in the dataset."""
return self._obj.map(ecdf, dims=dims, **kwargs).rename({"ecdf_axis": "plot_axis"})
# TODO: implement ecdf here so it doesn't depend on numba
return self._apply(ecdf, dims=dims, **kwargs).rename(ecdf_axis="plot_axis")


@register_datatree_accessor("azstats")
Expand Down
9 changes: 5 additions & 4 deletions src/arviz_stats/base/density.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from arviz_base import rcParams
from scipy.fftpack import fft
from scipy.optimize import brentq
from scipy.signal import convolve, convolve2d, gaussian # pylint: disable=no-name-in-module
from scipy.signal import convolve, convolve2d
from scipy.signal.windows import gaussian
from scipy.sparse import coo_matrix
from scipy.special import ive # pylint: disable=no-name-in-module

Expand Down Expand Up @@ -393,7 +394,7 @@ def kde(da, dims=None, grid_len=512, **kwargs):
return out.assign_coords({"bw" if da.name is None else f"bw_{da.name}": bw})


def _kde(x, circular=False, **kwargs):
def _kde(x, circular=False, grid_len=512, **kwargs):
"""One dimensional density estimation.
It is a wrapper around ``kde_linear()`` and ``kde_circular()``.
Expand Down Expand Up @@ -505,7 +506,7 @@ def _kde(x, circular=False, **kwargs):
if x.size == 0 or np.all(x == x[0]):
warnings.warn("Your data appears to have a single value or no finite values")

return np.zeros(2), np.array([np.nan] * 2), np.nan
return np.zeros(grid_len), np.full(grid_len, np.nan), np.nan

if circular:
if circular == "degrees":
Expand All @@ -514,7 +515,7 @@ def _kde(x, circular=False, **kwargs):
else:
kde_fun = kde_linear

return kde_fun(x, **kwargs)
return kde_fun(x, grid_len=grid_len, **kwargs)


def kde_linear(
Expand Down
6 changes: 3 additions & 3 deletions tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@
envlist =
check
docs
{py39,py310,py311}{,-coverage}
{py310,py311,py312}{,-coverage}
# See https://tox.readthedocs.io/en/latest/example/package.html#flit
isolated_build = True
isolated_build_env = build

[gh-actions]
python =
3.9: py39
3.10: check, py310
3.11: py311
3.12: py312

[testenv]
basepython =
py39: python3.9
py310: python3.10
py311: python3.11
py312: python3.12
# See https://github.com/tox-dev/tox/issues/1548
{check,docs,cleandocs,viewdocs,build}: python3
setenv =
Expand Down

0 comments on commit a40917b

Please sign in to comment.