Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

(feat): support for zarr-python>=3.0.0b0 #1726

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ test = [
"loompy>=3.0.5",
"pytest>=8.2",
"pytest-cov>=2.10",
"zarr<3.0.0a0",
"zarr>=3.0.0b0",
"matplotlib",
"scikit-learn",
"openpyxl",
Expand Down
2 changes: 1 addition & 1 deletion src/anndata/_core/sparse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,7 +505,7 @@ def append(self, sparse_matrix: ss.csr_matrix | ss.csc_matrix | SpArray) -> None
f"Matrices must have same format. Currently are "
f"{self.format!r} and {sparse_matrix.format!r}"
)
indptr_offset = len(self.group["indices"])
[indptr_offset] = self.group["indices"].shape
if self.group["indptr"].dtype == np.int32:
new_nnz = indptr_offset + len(sparse_matrix.indices)
if new_nnz >= np.iinfo(np.int32).max:
Expand Down
92 changes: 65 additions & 27 deletions src/anndata/_io/specs/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@
# Backwards compat sparse arrays
if "h5sparse_format" in elem.attrs:
return sparse_dataset(elem).to_memory()
return {k: _reader.read_elem(v) for k, v in elem.items()}
return {k: _reader.read_elem(v) for k, v in dict(elem).items()}
elif isinstance(elem, h5py.Dataset):
return h5ad.read_dataset(elem) # TODO: Handle legacy

Expand All @@ -161,7 +161,7 @@
# Backwards compat sparse arrays
if "h5sparse_format" in elem.attrs:
return sparse_dataset(elem).to_memory()
return {k: _reader.read_elem(v) for k, v in elem.items()}
return {k: _reader.read_elem(v) for k, v in dict(elem).items()}

Check warning on line 164 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L164

Added line #L164 was not covered by tests
elif isinstance(elem, ZarrArray):
return zarr.read_dataset(elem) # TODO: Handle legacy

Expand Down Expand Up @@ -334,7 +334,7 @@
@_REGISTRY.register_read(H5Group, IOSpec("dict", "0.1.0"))
@_REGISTRY.register_read(ZarrGroup, IOSpec("dict", "0.1.0"))
def read_mapping(elem: GroupStorageType, *, _reader: Reader) -> dict[str, AxisStorable]:
    """Read a stored group encoded as a ``dict``, recursively reading each child.

    ``dict(elem)`` snapshots the group's members before iteration, which keeps
    behavior uniform across h5py and zarr group mapping implementations.
    """
    result: dict[str, AxisStorable] = {}
    for key, child in dict(elem).items():
        result[key] = _reader.read_elem(child)
    return result
return {k: _reader.read_elem(v) for k, v in dict(elem).items()}


@_REGISTRY.register_write(H5Group, dict, IOSpec("dict", "0.1.0"))
Expand Down Expand Up @@ -390,7 +390,7 @@
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
"""Write methods which underlying library handles natively."""
f.create_dataset(k, data=elem, **dataset_kwargs)
f.create_dataset(k, data=elem, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)


_REGISTRY.register_write(H5Group, CupyArray, IOSpec("array", "0.2.0"))(
Expand All @@ -411,8 +411,12 @@
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
import dask.array as da
import zarr

g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
if Version(zarr.__version__) >= Version("3.0.0b0"):
g = f.require_array(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)
else:
g = f.require_dataset(k, shape=elem.shape, dtype=elem.dtype, **dataset_kwargs)

Check warning on line 419 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L419

Added line #L419 was not covered by tests
da.store(elem, g, lock=GLOBAL_LOCK)


Expand Down Expand Up @@ -505,23 +509,37 @@
_writer: Writer,
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
import numcodecs

if Version(numcodecs.__version__) < Version("0.13"):
msg = "Old numcodecs version detected. Please update for improved performance and stability."
warnings.warn(msg)
# Workaround for https://github.com/zarr-developers/numcodecs/issues/514
if hasattr(elem, "flags") and not elem.flags.writeable:
elem = elem.copy()

f.create_dataset(
k,
shape=elem.shape,
dtype=object,
object_codec=numcodecs.VLenUTF8(),
**dataset_kwargs,
)
f[k][:] = elem
import zarr

if Version(zarr.__version__) < Version("3.0.0b0"):
import numcodecs

Check warning on line 515 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L515

Added line #L515 was not covered by tests

if Version(numcodecs.__version__) < Version("0.13"):
msg = "Old numcodecs version detected. Please update for improved performance and stability."
warnings.warn(msg)

Check warning on line 519 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L517-L519

Added lines #L517 - L519 were not covered by tests
# Workaround for https://github.com/zarr-developers/numcodecs/issues/514
if hasattr(elem, "flags") and not elem.flags.writeable:
elem = elem.copy()

Check warning on line 522 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L521-L522

Added lines #L521 - L522 were not covered by tests

f.create_dataset(

Check warning on line 524 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L524

Added line #L524 was not covered by tests
k,
shape=elem.shape,
dtype=object,
object_codec=numcodecs.VLenUTF8(),
**dataset_kwargs,
)
f[k][:] = elem

Check warning on line 531 in src/anndata/_io/specs/methods.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/methods.py#L531

Added line #L531 was not covered by tests
else:
from zarr.codecs import VLenUTF8Codec

f.create_array(
k,
shape=elem.shape,
dtype=str,
codecs=[VLenUTF8Codec()],
**dataset_kwargs,
)
f[k][:] = elem


###############
Expand Down Expand Up @@ -576,7 +594,9 @@
):
from anndata.compat import _to_fixed_length_strings

f.create_dataset(k, data=_to_fixed_length_strings(elem), **dataset_kwargs)
f.create_dataset(
k, data=_to_fixed_length_strings(elem), shape=elem.shape, **dataset_kwargs
)


#################
Expand All @@ -602,9 +622,27 @@
if isinstance(f, H5Group) and "maxshape" not in dataset_kwargs:
dataset_kwargs = dict(maxshape=(None,), **dataset_kwargs)

g.create_dataset("data", data=value.data, **dataset_kwargs)
g.create_dataset("indices", data=value.indices, **dataset_kwargs)
g.create_dataset("indptr", data=value.indptr, dtype=indptr_dtype, **dataset_kwargs)
g.create_dataset(
"data",
data=value.data,
shape=value.data.shape,
dtype=value.data.dtype,
**dataset_kwargs,
)
g.create_dataset(
"indices",
data=value.indices,
shape=value.indices.shape,
dtype=value.indices.dtype,
**dataset_kwargs,
)
g.create_dataset(
"indptr",
data=value.indptr,
shape=value.indptr.shape,
dtype=indptr_dtype,
**dataset_kwargs,
)


write_csr = partial(write_sparse_compressed, fmt="csr")
Expand Down Expand Up @@ -1117,7 +1155,7 @@
_writer: Writer,
dataset_kwargs: Mapping[str, Any] = MappingProxyType({}),
):
return f.create_dataset(key, data=np.array(value), **dataset_kwargs)
return f.create_dataset(key, data=np.array(value), shape=(), **dataset_kwargs)


def write_hdf5_scalar(
Expand Down
17 changes: 13 additions & 4 deletions src/anndata/_io/specs/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
from types import MappingProxyType
from typing import TYPE_CHECKING, Generic, TypeVar

from packaging.version import Version

from anndata._io.utils import report_read_key_on_error, report_write_key_on_error
from anndata._types import Read, ReadDask, _ReadDaskInternal, _ReadInternal
from anndata.compat import DaskArray, _read_attr
from anndata.compat import DaskArray, ZarrGroup, _read_attr

if TYPE_CHECKING:
from collections.abc import Callable, Generator, Iterable
Expand Down Expand Up @@ -341,11 +343,18 @@
return lambda *_, **__: None

# Normalize k to absolute path
if not PurePosixPath(k).is_absolute():
k = str(PurePosixPath(store.name) / k)
if isinstance(store, ZarrGroup):
import zarr

if Version(zarr.__version__) < Version("3.0.0b0"):
if not PurePosixPath(k).is_absolute():
k = str(PurePosixPath(store.name) / k)

Check warning on line 351 in src/anndata/_io/specs/registry.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/_io/specs/registry.py#L350-L351

Added lines #L350 - L351 were not covered by tests

if k == "/":
store.clear()
if isinstance(store, ZarrGroup):
store.store.clear()
else:
store.clear()
elif k in store:
del store[k]

Expand Down
4 changes: 2 additions & 2 deletions src/anndata/_io/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def report_read_key_on_error(func):
>>> @report_read_key_on_error
... def read_arr(group):
... raise NotImplementedError()
>>> z = zarr.open("tmp.zarr")
>>> z = zarr.open("tmp.zarr", mode="w")
>>> z["X"] = [1, 2, 3]
>>> read_arr(z["X"]) # doctest: +SKIP
"""
Expand Down Expand Up @@ -228,7 +228,7 @@ def report_write_key_on_error(func):
>>> @report_write_key_on_error
... def write_arr(group, key, val):
... raise NotImplementedError()
>>> z = zarr.open("tmp.zarr")
>>> z = zarr.open("tmp.zarr", mode="w")
>>> X = [1, 2, 3]
>>> write_arr(z, "X", X) # doctest: +SKIP
"""
Expand Down
4 changes: 2 additions & 2 deletions src/anndata/_io/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def write_zarr(
f.attrs.setdefault("encoding-version", "0.1.0")

def callback(func, s, k, elem, dataset_kwargs, iospec):
    # Inject the user-supplied `chunks` (closure variable from the enclosing
    # write_zarr — TODO confirm, the enclosing def is outside this view) only
    # for the dense main matrix: sparse matrices are written as sub-groups and
    # manage their own chunking.
    # NOTE(review): the key compares equal to "X" (relative name, no leading
    # slash) — presumably matching zarr-python 3 key conventions; verify
    # against the writer's key normalization.
    if chunks is not None and not isinstance(elem, sparse.spmatrix) and k == "X":
        dataset_kwargs = dict(dataset_kwargs, chunks=chunks)
    func(s, k, elem, dataset_kwargs=dataset_kwargs)
func(s, k, elem, dataset_kwargs=dataset_kwargs)

Expand Down Expand Up @@ -73,7 +73,7 @@ def callback(func, elem_name: str, elem, iospec):
return AnnData(
**{
k: read_dispatched(v, callback)
for k, v in elem.items()
for k, v in dict(elem).items()
if not k.startswith("raw.")
}
)
Expand Down
4 changes: 2 additions & 2 deletions src/anndata/compat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,8 +80,8 @@ def __exit__(self, *_exc_info) -> None:
#############################

if find_spec("zarr") or TYPE_CHECKING:
from zarr.core import Array as ZarrArray
from zarr.hierarchy import Group as ZarrGroup
from zarr import Array as ZarrArray
from zarr import Group as ZarrGroup
else:

class ZarrArray:
Expand Down
2 changes: 1 addition & 1 deletion src/anndata/experimental/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def callback(func, elem_name: str, elem, iospec):
elif iospec.encoding_type == "array":
return elem
elif iospec.encoding_type == "dict":
return {k: read_as_backed(v) for k, v in elem.items()}
return {k: read_as_backed(v) for k, v in dict(elem).items()}
else:
return func(elem)

Expand Down
6 changes: 3 additions & 3 deletions src/anndata/tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1040,17 +1040,17 @@
]

if find_spec("zarr") or TYPE_CHECKING:
from zarr import DirectoryStore
from zarr.storage import LocalStore
else:

class DirectoryStore:
class LocalStore:

Check warning on line 1046 in src/anndata/tests/helpers.py

View check run for this annotation

Codecov / codecov/patch

src/anndata/tests/helpers.py#L1046

Added line #L1046 was not covered by tests
def __init__(self, *_args, **_kwargs) -> None:
cls_name = type(self).__name__
msg = f"zarr must be imported to create a {cls_name} instance."
raise ImportError(msg)


class AccessTrackingStore(DirectoryStore):
class AccessTrackingStore(LocalStore):
_access_count: Counter[str]
_accessed_keys: dict[str, list[str]]

Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def tokenize_anndata(adata: ad.AnnData):
res.extend([tokenize(adata.obs), tokenize(adata.var)])
for attr in ["obsm", "varm", "obsp", "varp", "layers"]:
elem = getattr(adata, attr)
res.append(tokenize(list(elem.items())))
res.append(tokenize(list(dict(elem).items())))
res.append(joblib.hash(adata.uns))
if adata.raw is not None:
res.append(tokenize(adata.raw.to_adata()))
Expand Down
5 changes: 4 additions & 1 deletion tests/test_backed_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,10 @@ def read_zarr_backed(path):
def callback(func, elem_name, elem, iospec):
if iospec.encoding_type == "anndata" or elem_name.endswith("/"):
return AnnData(
**{k: read_dispatched(v, callback) for k, v in elem.items()}
**{
k: read_dispatched(v, callback)
for k, v in dict(elem).items()
}
)
if iospec.encoding_type in {"csc_matrix", "csr_matrix"}:
return sparse_dataset(elem)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_io_elementwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def test_read_zarr_from_group(tmp_path, consolidated):
write_elem(z, "table/table", adata)

if consolidated:
zarr.convenience.consolidate_metadata(z.store)
zarr.consolidate_metadata(z.store)

if consolidated:
read_func = zarr.open_consolidated
Expand Down
Loading