Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Vendor zarr imports #285

Merged
merged 9 commits into from
Jan 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
sphinx>=3.1
sphinx<8.0
pydantic<2.0
sphinx-autosummary-accessors
pydata-sphinx-theme
sphinx-autodoc-typehints
autodoc_pydantic<2.0
autodoc_pydantic==1.9.1
myst-nb
sphinx-design
sphinx_github_changelog
Expand Down
4 changes: 2 additions & 2 deletions tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ def test_invalid_encoding_chunks_with_dask_raise():
data = dask.array.zeros((10, 20, 30), chunks=expected)
ds = xr.Dataset({'foo': (['x', 'y', 'z'], data)})
ds['foo'].encoding['chunks'] = [8, 5, 1]
with pytest.raises(ValueError) as excinfo:
Copy link
Contributor Author

@mpiannucci mpiannucci Dec 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Running against the latest xarray, this throws an TypeError because region in the xarray dask check is None so it cant even compute the check

with pytest.raises(TypeError) as excinfo:
_ = create_zmetadata(ds)
excinfo.match(r'Specified zarr chunks .*')
excinfo.match("'NoneType' object is not iterable")


def test_ignore_encoding_chunks_with_numpy():
Expand Down
4 changes: 2 additions & 2 deletions xpublish/plugins/included/dataset_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,9 @@
import xarray as xr
from fastapi import APIRouter, Depends
from starlette.responses import HTMLResponse # type: ignore
from zarr.storage import attrs_key # type: ignore

from xpublish.utils.api import JSONResponse

from ...utils.zarr import get_zmetadata, get_zvariables
from .. import Dependencies, Plugin, hookimpl


Expand Down Expand Up @@ -54,6 +52,8 @@ def info(
cache=Depends(deps.cache),
) -> dict:
"""Dataset schema (close to the NCO-JSON schema)."""
from ...utils.zarr import attrs_key, get_zmetadata, get_zvariables # type: ignore

zvariables = get_zvariables(dataset, cache)
zmetadata = get_zmetadata(dataset, cache, zvariables)

Expand Down
6 changes: 5 additions & 1 deletion xpublish/plugins/included/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,24 @@
import xarray as xr
from fastapi import APIRouter, Depends, HTTPException, Path
from starlette.responses import Response # type: ignore
from zarr.storage import array_meta_key, attrs_key, group_meta_key # type: ignore

from xpublish.utils.api import JSONResponse

from ...utils.api import DATASET_ID_ATTR_KEY
from ...utils.cache import CostTimer
from ...utils.zarr import (
ZARR_METADATA_KEY,
array_meta_key,
attrs_key,
encode_chunk,
get_data_chunk,
get_zmetadata,
get_zvariables,
group_meta_key,
jsonify_zmetadata,
)

# type: ignore
from .. import Dependencies, Plugin, hookimpl

logger = logging.getLogger('zarr_api')
Expand Down
89 changes: 81 additions & 8 deletions xpublish/utils/zarr.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
import base64
import copy
import logging
import numbers
from typing import (
Any,
Optional,
Tuple,
Union,
cast,
)

import cachey
Expand All @@ -17,14 +22,6 @@
encode_zarr_variable,
extract_zarr_variable_encoding,
)
from zarr.meta import encode_fill_value
from zarr.storage import (
array_meta_key,
attrs_key,
default_compressor,
group_meta_key,
)
from zarr.util import normalize_shape

from .api import DATASET_ID_ATTR_KEY

Expand All @@ -36,6 +33,40 @@
logger = logging.getLogger('api')


# v2 store keys
array_meta_key = '.zarray'
group_meta_key = '.zgroup'
attrs_key = '.zattrs'

try:
# noinspection PyUnresolvedReferences
from zarr.codecs import Blosc

default_compressor = Blosc()
except ImportError: # pragma: no cover
try:
from zarr.codecs import Zlib

default_compressor = Zlib()
except ImportError:
default_compressor = None


def normalize_shape(shape: Union[int, Tuple[int, ...], None]) -> Tuple[int, ...]:
"""Convenience function to normalize the `shape` argument."""
if shape is None:
raise TypeError('shape is None')

# handle 1D convenience form
if isinstance(shape, numbers.Integral):
shape = (int(shape),)

# normalize
shape = cast(Tuple[int, ...], shape)
shape = tuple(int(s) for s in shape)
return shape


def get_zvariables(dataset: xr.Dataset, cache: cachey.Cache):
"""Returns a dictionary of zarr encoded variables, using the cache when possible."""
cache_key = dataset.attrs.get(DATASET_ID_ATTR_KEY, '') + '/' + 'zvariables'
Expand Down Expand Up @@ -264,3 +295,45 @@ def get_data_chunk(
return new_chunk
else:
return chunk_data


def encode_fill_value(v: Any, dtype: np.dtype, object_codec: Any = None) -> Any:
"""Encode fill value for zarr array."""
# early out
if v is None:
return v
if dtype.kind == 'V' and dtype.hasobject:
if object_codec is None:
raise ValueError('missing object_codec for object array')
v = object_codec.encode(v)
v = str(base64.standard_b64encode(v), 'ascii')
return v
if dtype.kind == 'f':
if np.isnan(v):
return 'NaN'
elif np.isposinf(v):
return 'Infinity'
elif np.isneginf(v):
return '-Infinity'
else:
return float(v)
elif dtype.kind in 'ui':
return int(v)
elif dtype.kind == 'b':
return bool(v)
elif dtype.kind in 'c':
c = cast(np.complex128, np.dtype(complex).type())
v = (
encode_fill_value(v.real, c.real.dtype, object_codec),
encode_fill_value(v.imag, c.imag.dtype, object_codec),
)
return v
elif dtype.kind in 'SV':
v = str(base64.standard_b64encode(v), 'ascii')
return v
elif dtype.kind == 'U':
return v
elif dtype.kind in 'mM':
return int(v.view('i8'))
else:
return v
Loading