Skip to content

Commit

Permalink
Backport PR #1729: (feat): support ellipsis indexing (#1734)
Browse files Browse the repository at this point in the history
* Backport PR #1729: (feat): support ellipsis indexing

* (fix): ellipsis type

* (fix): patch versions?
  • Loading branch information
ilan-gold authored Oct 30, 2024
1 parent 34ab179 commit 535dd52
Show file tree
Hide file tree
Showing 8 changed files with 136 additions and 16 deletions.
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,8 @@
("py:class", "awkward.highlevel.Array"),
("py:class", "anndata._core.sparse_dataset.BaseCompressedSparseDataset"),
("py:obj", "numpy._typing._array_like._ScalarType_co"),
# https://github.com/tox-dev/sphinx-autodoc-typehints/issues/498
("py:class", "types.EllipsisType"),
]


Expand Down
1 change: 1 addition & 0 deletions docs/release-notes/1729.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Add support for ellipsis indexing of the {class}`~anndata.AnnData` object {user}`ilan-gold`
3 changes: 2 additions & 1 deletion src/anndata/_core/anndata.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,8 +56,9 @@
from typing import Any, Literal

from .._types import ArrayDataStructureType
from ..compat import Index1D
from .aligned_mapping import AxisArraysView, LayersView, PairwiseArraysView
from .index import Index, Index1D
from .index import Index


# for backwards compat
Expand Down
36 changes: 22 additions & 14 deletions src/anndata/_core/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,8 @@ def _normalize_indices(
if isinstance(index, pd.Series):
index: Index = index.values
if isinstance(index, tuple):
if len(index) > 2:
raise ValueError("AnnData can only be sliced in rows and columns.")
# deal with pd.Series
# TODO: The series should probably be aligned first
if isinstance(index[1], pd.Series):
index = index[0], index[1].values
if isinstance(index[0], pd.Series):
index = index[0].values, index[1]
index = tuple(i.values if isinstance(i, pd.Series) else i for i in index)
ax0, ax1 = unpack_index(index)
ax0 = _normalize_index(ax0, names0)
ax1 = _normalize_index(ax1, names1)
Expand Down Expand Up @@ -105,8 +99,7 @@ def name_idx(i):
"are not valid obs/ var names or indices."
)
return positions # np.ndarray[int]
else:
raise IndexError(f"Unknown indexer {indexer!r} of type {type(indexer)}")
raise IndexError(f"Unknown indexer {indexer!r} of type {type(indexer)}")


def _fix_slice_bounds(s: slice, length: int) -> slice:
Expand All @@ -130,13 +123,28 @@ def _fix_slice_bounds(s: slice, length: int) -> slice:

def unpack_index(index: Index) -> tuple[Index1D, Index1D]:
if not isinstance(index, tuple):
if index is Ellipsis:
index = slice(None)
return index, slice(None)
elif len(index) == 2:
num_ellipsis = sum(i is Ellipsis for i in index)
if num_ellipsis > 1:
raise IndexError("an index can only have a single ellipsis ('...')")
# If index has Ellipsis, filter it out (and if not, error)
if len(index) > 2:
if not num_ellipsis:
raise IndexError("Received a length 3 index without an ellipsis")
index = tuple(i for i in index if i is not Ellipsis)
return index
elif len(index) == 1:
return index[0], slice(None)
else:
raise IndexError("invalid number of indices")
# If index has Ellipsis, replace it with slice
if len(index) == 2:
index = tuple(slice(None) if i is Ellipsis else i for i in index)
return index
if len(index) == 1:
index = index[0]
if index is Ellipsis:
index = slice(None)
return index, slice(None)
raise IndexError("invalid number of indices")


@singledispatch
Expand Down
17 changes: 16 additions & 1 deletion src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,28 @@
if TYPE_CHECKING:
from typing import Any

if sys.version_info >= (3, 10):
from types import EllipsisType
else:
EllipsisType = type(Ellipsis)


class Empty:
pass


EllipsisType = type(Ellipsis)
Index1D = Union[slice, int, str, np.int64, np.ndarray]
Index = Union[Index1D, tuple[Index1D, Index1D], spmatrix]
IndexRest = Union[Index1D, EllipsisType]
Index = Union[
IndexRest,
tuple[Index1D, IndexRest],
tuple[IndexRest, Index1D],
tuple[Index1D, Index1D, EllipsisType],
tuple[EllipsisType, Index1D, Index1D],
tuple[Index1D, EllipsisType, Index1D],
spmatrix,
]
H5Group = h5py.Group
H5Array = h5py.Dataset
H5File = h5py.File
Expand Down
53 changes: 53 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,63 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pytest

from anndata.tests.helpers import subset_func # noqa: F401

if TYPE_CHECKING:
from types import EllipsisType


@pytest.fixture
def backing_h5ad(tmp_path):
return tmp_path / "test.h5ad"


@pytest.fixture(
params=[
pytest.param((..., (slice(None), slice(None))), id="ellipsis"),
pytest.param(((...,), (slice(None), slice(None))), id="ellipsis_tuple"),
pytest.param(
((..., slice(0, 10)), (slice(None), slice(0, 10))), id="obs-ellipsis"
),
pytest.param(
((slice(0, 10), ...), (slice(0, 10), slice(None))), id="var-ellipsis"
),
pytest.param(
((slice(0, 10), slice(0, 10), ...), (slice(0, 10), slice(0, 10))),
id="obs-var-ellipsis",
),
pytest.param(
((..., slice(0, 10), slice(0, 10)), (slice(0, 10), slice(0, 10))),
id="ellipsis-obs-var",
),
pytest.param(
((slice(0, 10), ..., slice(0, 10)), (slice(0, 10), slice(0, 10))),
id="obs-ellipsis-var",
),
]
)
def ellipsis_index_with_equivalent(
request,
) -> tuple[tuple[EllipsisType | slice, ...] | EllipsisType, tuple[slice, slice]]:
return request.param


@pytest.fixture
def ellipsis_index(
ellipsis_index_with_equivalent: tuple[
tuple[EllipsisType | slice, ...] | EllipsisType, tuple[slice, slice]
],
) -> tuple[EllipsisType | slice, ...] | EllipsisType:
return ellipsis_index_with_equivalent[0]


@pytest.fixture
def equivalent_ellipsis_index(
ellipsis_index_with_equivalent: tuple[
tuple[EllipsisType | slice, ...] | EllipsisType, tuple[slice, slice]
],
) -> tuple[slice, slice]:
return ellipsis_index_with_equivalent[1]
12 changes: 12 additions & 0 deletions tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
if TYPE_CHECKING:
from collections.abc import Callable, Generator, Sequence
from pathlib import Path
from types import EllipsisType

from _pytest.mark import ParameterSet
from numpy.typing import ArrayLike, NDArray
Expand Down Expand Up @@ -126,6 +127,17 @@ def test_backed_indexing(
assert_equal(csr_mem[:, var_idx].X, dense_disk[:, var_idx].X)


def test_backed_ellipsis_indexing(
ondisk_equivalent_adata: tuple[AnnData, AnnData, AnnData, AnnData],
ellipsis_index: tuple[EllipsisType | slice, ...] | EllipsisType,
equivalent_ellipsis_index: tuple[slice, slice],
):
csr_mem, csr_disk, csc_disk, _ = ondisk_equivalent_adata

assert_equal(csr_mem.X[equivalent_ellipsis_index], csr_disk.X[ellipsis_index])
assert_equal(csr_mem.X[equivalent_ellipsis_index], csc_disk.X[ellipsis_index])


def make_randomized_mask(size: int) -> np.ndarray:
randomized_mask = np.zeros(size, dtype=bool)
inds = np.random.choice(size, 20, replace=False)
Expand Down
28 changes: 28 additions & 0 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from contextlib import ExitStack
from copy import deepcopy
from operator import mul
from typing import TYPE_CHECKING

import joblib
import numpy as np
Expand All @@ -29,6 +30,9 @@
)
from anndata.utils import asarray

if TYPE_CHECKING:
from types import EllipsisType

IGNORE_SPARSE_EFFICIENCY_WARNING = pytest.mark.filterwarnings(
"ignore:Changing the sparsity structure:scipy.sparse.SparseEfficiencyWarning"
)
Expand Down Expand Up @@ -767,6 +771,30 @@ def test_dataframe_view_index_setting():
assert a2.obs.index.values.tolist() == ["a", "b"]


def test_ellipsis_index(
ellipsis_index: tuple[EllipsisType | slice, ...] | EllipsisType,
equivalent_ellipsis_index: tuple[slice, slice],
matrix_type,
):
adata = gen_adata((10, 10), X_type=matrix_type, **GEN_ADATA_DASK_ARGS)
subset_ellipsis = adata[ellipsis_index]
subset = adata[equivalent_ellipsis_index]
assert_equal(subset_ellipsis, subset)


@pytest.mark.parametrize(
("index", "expected_error"),
[
((..., 0, ...), r"only have a single ellipsis"),
((0, 0, 0), r"Received a length 3 index"),
],
ids=["ellipsis-int-ellipsis", "int-int-int"],
)
def test_index_3d_errors(index: tuple[int | EllipsisType, ...], expected_error: str):
with pytest.raises(IndexError, match=expected_error):
gen_adata((10, 10))[index]


# @pytest.mark.parametrize("dim", ["obs", "var"])
# @pytest.mark.parametrize(
# ("idx", "pat"),
Expand Down

0 comments on commit 535dd52

Please sign in to comment.