-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Encode Dataset attributes containing Datasets as JSON
- Loading branch information
Showing
6 changed files
with
152 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
import xarray.testing as xt | ||
from xarray.backends.api import open_dataset, open_datatree | ||
|
||
from xarray_ms import xds_from_zarr, xds_to_zarr, xdt_from_zarr, xdt_to_zarr | ||
|
||
|
||
def test_dataset_roundtrip(simmed_ms, tmp_path):
    """Write a Dataset to a Zarr store, read it back and compare.

    The dataset opened from the ``simmed_ms`` fixture must be identical
    to the one recovered from the store written under ``tmp_path``.
    """
    original = open_dataset(simmed_ms)
    store = tmp_path / "test_dataset.zarr"
    xds_to_zarr(original, store, compute=True, consolidated=True)
    xt.assert_identical(original, xds_from_zarr(store))
|
||
|
||
def test_datatree_roundtrip(simmed_ms, tmp_path):
    """Write a DataTree to a Zarr store, read it back and compare.

    The tree opened from the ``simmed_ms`` fixture must be identical
    to the one recovered from the store written under ``tmp_path``.
    """
    original = open_datatree(simmed_ms)
    store = tmp_path / "test_datatree.zarr"
    xdt_to_zarr(original, store, compute=True, consolidated=True)
    xt.assert_identical(original, xdt_from_zarr(store))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
__all__ = ["xds_from_zarr", "xdt_from_zarr", "xds_to_zarr", "xdt_to_zarr"] | ||
|
||
from xarray_ms.core import xds_from_zarr, xds_to_zarr, xdt_from_zarr, xdt_to_zarr |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import json | ||
|
||
import xarray | ||
from xarray.backends.api import open_datatree | ||
from xarray.core.dataset import Dataset | ||
from xarray.core.datatree import DataTree | ||
|
||
|
||
def encode_attributes(ds: Dataset) -> Dataset:
    """Replace a Dataset-valued ``antenna_xds`` attribute with a JSON string.

    If ``ds`` has no ``antenna_xds`` attribute it is returned unchanged.
    Otherwise the attribute must be an xarray Dataset; it is converted with
    ``to_dict`` and serialised with ``json.dumps``, and a dataset carrying
    the encoded attribute is returned via ``assign_attrs``.

    Raises
    ------
    TypeError
        If ``antenna_xds`` is present but is not an xarray Dataset.
    """
    antenna = ds.attrs.get("antenna_xds", None)

    # Nothing to encode.
    if antenna is None:
        return ds

    if not isinstance(antenna, Dataset):
        raise TypeError(
            f"antenna_xds attribute must be an xarray Dataset "
            f"but a {type(antenna)} was present"
        )

    encoded = json.dumps(antenna.to_dict())
    return ds.assign_attrs(antenna_xds=encoded)
|
||
|
||
def decode_attributes(ds: Dataset) -> Dataset:
    """Decode the antenna_xds attribute of a Dataset.

    Counterpart of :func:`encode_attributes`. A JSON-encoded ``antenna_xds``
    attribute is parsed with ``json.loads`` and rebuilt with
    ``Dataset.from_dict``; the result is attached via ``assign_attrs``.

    Parameters
    ----------
    ds
        Dataset whose ``antenna_xds`` attribute, if present, is either a
        JSON string or an already-decoded xarray Dataset.

    Returns
    -------
    Dataset
        ``ds`` itself when the attribute is absent or already a Dataset,
        otherwise a dataset carrying the decoded attribute.

    Raises
    ------
    TypeError
        If ``antenna_xds`` is present but is neither a string nor a Dataset.
    """
    ant_xds = ds.attrs.get("antenna_xds")

    # encode_attributes passes datasets without an antenna_xds attribute
    # through untouched, so decoding must accept them too instead of
    # failing with a KeyError.
    if ant_xds is None:
        return ds

    # Already decoded: nothing to do.
    if isinstance(ant_xds, Dataset):
        return ds

    if isinstance(ant_xds, str):
        antenna_dict = json.loads(ant_xds)
        return ds.assign_attrs(antenna_xds=Dataset.from_dict(antenna_dict))

    raise TypeError(
        f"antenna_xds must be an xarray Dataset or a JSON encoded Dataset "
        f"but a {type(ant_xds)} was present"
    )
|
||
|
||
def xds_from_zarr(*args, **kwargs):
    """Open a Zarr store as a Measurement Set-like :class:`~xarray.Dataset`.

    Every positional and keyword argument is forwarded to
    :func:`xarray.open_zarr`; the opened dataset is then passed through
    :func:`decode_attributes` before being returned.
    """
    opened = xarray.open_zarr(*args, **kwargs)
    return decode_attributes(opened)
|
||
|
||
def xds_to_zarr(ds: Dataset, *args, **kwargs):
    """Write a Measurement Set-like :class:`~xarray.Dataset` to a Zarr store.

    Thin wrapper around :meth:`xarray.Dataset.to_zarr`: ``ds`` is first
    passed through :func:`encode_attributes`, then all remaining positional
    and keyword arguments are forwarded unchanged.

    Returns
    -------
    The value returned by :meth:`xarray.Dataset.to_zarr`. Per the xarray
    documentation this is a store object, or a delayed object when
    ``compute=False`` is requested, which the caller must keep in order to
    trigger the write. The return annotation is therefore left open rather
    than declared as ``None``.
    """
    encoded = encode_attributes(ds)
    return encoded.to_zarr(*args, **kwargs)
|
||
|
||
def xdt_from_zarr(*args, **kwargs):
    """Open a Zarr store as a Measurement Set-like
    :class:`~xarray.core.datatree.DataTree`.

    Every positional and keyword argument is forwarded to
    :func:`xarray.backends.api.open_datatree`; :func:`decode_attributes`
    is then mapped over the resulting tree with ``map_over_subtree``.
    """
    tree = open_datatree(*args, **kwargs)
    return tree.map_over_subtree(decode_attributes)
|
||
|
||
def xdt_to_zarr(dt: DataTree, *args, **kwargs) -> None:
    """Write a Measurement Set-like :class:`~xarray.core.datatree.DataTree`
    to a Zarr store.

    :func:`encode_attributes` is mapped over the tree with
    ``map_over_subtree``; the remaining positional and keyword arguments
    are then forwarded to :meth:`xarray.core.datatree.DataTree.to_zarr`.
    """
    encoded_tree = dt.map_over_subtree(encode_attributes)
    return encoded_tree.to_zarr(*args, **kwargs)