Utilise xarray's preferred_chunks functionality (#44)
sjperkins authored Nov 5, 2024
1 parent 7e1f8cd commit 07bd809
Showing 9 changed files with 137 additions and 75 deletions.
30 changes: 15 additions & 15 deletions README.rst
@@ -24,28 +24,28 @@ to be developed on well-understood MSv2 data.
>>> import xarray_ms
>>> from xarray.backends.api import datatree
>>> dt = open_datatree("/data/L795830_SB001_uv.MS/",
partition_chunks={"time": 2000, "baseline": 1000})
preferred_chunks={"time": 2000, "baseline_id": 1000})
>>> dt
<xarray.DataTree>
Group: /
└── Group: /DATA_DESC_ID=0,FIELD_ID=0,OBSERVATION_ID=0
│ Dimensions: (time: 28760, baseline: 2775, frequency: 16,
│ Dimensions: (time: 28760, baseline_id: 2775, frequency: 16,
│ polarization: 4, uvw_label: 3)
│ Coordinates:
│ antenna1_name (baseline) object 22kB ...
│ antenna2_name (baseline) object 22kB ...
│ baseline_id (baseline) int64 22kB ...
│ antenna1_name (baseline_id) object 22kB ...
│ antenna2_name (baseline_id) object 22kB ...
│ baseline_id (baseline_id) int64 22kB ...
* frequency (frequency) float64 128B 1.202e+08 ... 1.204e+08
* polarization (polarization) <U2 32B 'XX' 'XY' 'YX' 'YY'
* time (time) float64 230kB 1.601e+09 ... 1.601e+09
│ Dimensions without coordinates: baseline, uvw_label
│ Dimensions without coordinates: uvw_label
│ Data variables:
│ EFFECTIVE_INTEGRATION_TIME (time, baseline) float64 638MB ...
│ FLAG (time, baseline, frequency, polarization) uint8 5GB ...
│ TIME_CENTROID (time, baseline) float64 638MB ...
│ UVW (time, baseline, uvw_label) float64 2GB ...
│ VISIBILITY (time, baseline, frequency, polarization) complex64 41GB ...
│ WEIGHT (time, baseline, frequency, polarization) float32 20GB ...
│ EFFECTIVE_INTEGRATION_TIME (time, baseline_id) float64 638MB ...
│ FLAG (time, baseline_id, frequency, polarization) uint8 5GB ...
│ TIME_CENTROID (time, baseline_id) float64 638MB ...
│ UVW (time, baseline_id, uvw_label) float64 2GB ...
│ VISIBILITY (time, baseline_id, frequency, polarization) complex64 41GB ...
│ WEIGHT (time, baseline_id, frequency, polarization) float32 20GB ...
│ Attributes:
│ version: 4.0.0
│ creation_date: 2024-09-18T10:49:55.133908+00:00
@@ -54,9 +54,9 @@ to be developed on well-understood MSv2 data.
Dimensions: (antenna_name: 74,
cartesian_pos_label/ellipsoid_pos_label: 3)
Coordinates:
baseline_antenna1_name (baseline) object 22kB ...
baseline_antenna2_name (baseline) object 22kB ...
baseline_id (baseline) int64 22kB ...
baseline_antenna1_name (baseline_id) object 22kB ...
baseline_antenna2_name (baseline_id) object 22kB ...
baseline_id (baseline_id) int64 22kB ...
* frequency (frequency) float64 128B 1.202e+08 1.202e+08 ... 1.204e+08
* polarization (polarization) <U2 32B 'XX' 'XY' 'YX' 'YY'
* time (time) float64 230kB 1.601e+09 1.601e+09 ... 1.601e+09
3 changes: 3 additions & 0 deletions doc/source/changelog.rst
@@ -5,6 +5,9 @@ Changelog

X.Y.Z (DD-MM-YYYY)
------------------
* Rename ``baseline`` dimension to ``baseline_id`` (:pr:`44`)
* Loosen xarray version requirement to \>= 2024.9.0 (:pr:`44`)
* Change ``partition_chunks`` to ``preferred_chunks`` (:pr:`44`)
* Allow arcae to vary in the 0.2.x range (:pr:`42`)
* Pin xarray to 2024.9.0 (:pr:`42`)
* Add test case for irregular grids (:pr:`39`, :pr:`40`, :pr:`41`)
15 changes: 8 additions & 7 deletions doc/source/tutorial.rst
@@ -64,7 +64,7 @@ For example, one could select select some specific dimensions out:
dt = open_datatree(ms,
partition_columns=["DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"])
subdt = dt.isel(time=slice(1, 3), baseline=[1, 3, 5], frequency=slice(2, 4))
subdt = dt.isel(time=slice(1, 3), baseline_id=[1, 3, 5], frequency=slice(2, 4))
subdt
At this point, the ``subdt`` DataTree is still lazy -- no Data variables have been loaded
@@ -103,22 +103,22 @@ Per-partition chunking

Different chunking may be desired, especially when applied to
different channelisation and polarisation configurations.
In these cases, the ``partition_chunks`` argument can be used
In these cases, the ``preferred_chunks`` argument can be used
to specify different chunking setups for each partition.

.. ipython:: python
dt = open_datatree(ms, partition_columns=[
"DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"],
partition_chunks={
chunks={},
preferred_chunks={
(("DATA_DESC_ID", 0),): {"time": 2, "frequency": 4},
(("DATA_DESC_ID", 1),): {"time": 3, "frequency": 2}})
See the ``partition_chunks`` argument of
:meth:`xarray_ms.backend.msv2.entrypoint.MSv2EntryPoint.open_datatree`
See the ``preferred_chunks`` argument of
:meth:`~xarray_ms.backend.msv2.entrypoint.MSv2EntryPoint.open_datatree`
for more information.


.. ipython:: python
dt
@@ -139,7 +139,8 @@ this to a zarr_ store.
dt = open_datatree(ms, partition_columns=[
"DATA_DESC_ID", "FIELD_ID", "OBSERVATION_ID"],
partition_chunks={
chunks={},
preferred_chunks={
(("DATA_DESC_ID", 0),): {"time": 2, "frequency": 4},
(("DATA_DESC_ID", 1),): {"time": 3, "frequency": 2}})
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -8,7 +8,7 @@ readme = "README.rst"
[tool.poetry.dependencies]
python = "^3.10"
pytest = {version = "^8.0.0", optional = true, extras = ["testing"]}
xarray = "^2024.9.0, < 2024.10.0"
xarray = "^2024.9.0"
dask = {version = "^2024.5.0", optional = true, extras = ["testing"]}
distributed = {version = "^2024.5.0", optional = true, extras = ["testing"]}
cacheout = "^0.16.0"
38 changes: 20 additions & 18 deletions tests/test_backend.py
@@ -156,7 +156,7 @@ def test_open_datatree(simmed_ms):

# Works with default dask scheduler
with ExitStack() as stack:
dt = open_datatree(simmed_ms, partition_chunks=chunks)
dt = open_datatree(simmed_ms, preferred_chunks=chunks)
for ds in dt.values():
del ds.attrs["creation_date"]
xt.assert_identical(dt, mem_dt)
@@ -165,7 +165,7 @@ def test_open_datatree(simmed_ms):
with ExitStack() as stack:
cluster = stack.enter_context(LocalCluster(processes=True, n_workers=4))
stack.enter_context(Client(cluster))
dt = open_datatree(simmed_ms, partition_chunks=chunks)
dt = open_datatree(simmed_ms, preferred_chunks=chunks)
for ds in dt.values():
del ds.attrs["creation_date"]
xt.assert_identical(dt, mem_dt)
@@ -186,32 +186,34 @@ def test_open_datatree_chunking(simmed_ms):
and partition-specific chunking"""
dt = open_datatree(
simmed_ms,
partition_chunks={"time": 3, "frequency": 2},
chunks={},
preferred_chunks={"time": 3, "frequency": 2},
)

for child in dt.children:
ds = dt[child].ds
if ds.attrs["data_description_id"] == 0:
assert dict(ds.chunks) == {
"time": (3, 2),
"baseline": (6,),
"baseline_id": (6,),
"frequency": (2, 2, 2, 2),
"polarization": (4,),
"uvw_label": (3,),
}
elif ds.attrs["data_description_id"] == 1:
assert dict(ds.chunks) == {
"time": (3, 2),
"baseline": (6,),
"baseline_id": (6,),
"frequency": (2, 2),
"polarization": (2,),
"uvw_label": (3,),
}

dt = open_datatree(
simmed_ms,
partition_chunks={
"D=0": {"time": 2, "baseline": 2},
chunks={},
preferred_chunks={
"D=0": {"time": 2, "baseline_id": 2},
"D=1": {"time": 3, "frequency": 2},
},
)
@@ -221,26 +223,26 @@ def test_open_datatree_chunking(simmed_ms):
if ds.attrs["data_description_id"] == 0:
assert ds.chunks == {
"time": (2, 2, 1),
"baseline": (2, 2, 2),
"baseline_id": (2, 2, 2),
"frequency": (8,),
"polarization": (4,),
"uvw_label": (3,),
}
elif ds.attrs["data_description_id"] == 1:
assert ds.chunks == {
"time": (3, 2),
"baseline": (6,),
"baseline_id": (6,),
"frequency": (2, 2),
"polarization": (2,),
"uvw_label": (3,),
}

with pytest.warns(UserWarning, match="`partition_chunks` overriding `chunks`"):
dt = open_datatree(
simmed_ms,
chunks={},
partition_chunks={
"D=0": {"time": 2, "baseline": 2},
"D=1": {"time": 3, "frequency": 2},
},
)
# with pytest.warns(UserWarning, match="`preferred_chunks` overriding `chunks`"):
# dt = open_datatree(
# simmed_ms,
# chunks={},
# preferred_chunks={
# "D=0": {"time": 2, "baseline_id": 2},
# "D=1": {"time": 3, "frequency": 2},
# },
# )
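
The chunk tuples asserted in this test follow from splitting each dimension into blocks of the requested size, with a shorter final block when the size does not divide evenly. The sketch below reproduces that arithmetic. It assumes the first simulated partition has 5 time steps, 6 baselines and 8 frequency channels; these sizes are inferred from the asserted tuples and are not stated in the commit. `chunk_tuple` is a hypothetical helper written for illustration and is not part of xarray or xarray_ms.

```python
def chunk_tuple(dim_size: int, chunk_size: int) -> tuple[int, ...]:
    """Split a dimension of length dim_size into blocks of chunk_size.

    Hypothetical helper for illustration only; the final block is
    shorter when chunk_size does not divide dim_size evenly.
    """
    if dim_size < 0 or chunk_size <= 0:
        raise ValueError("dim_size must be >= 0 and chunk_size must be > 0")
    full_blocks, remainder = divmod(dim_size, chunk_size)
    return (chunk_size,) * full_blocks + ((remainder,) if remainder else ())


# Assumed dimension sizes, inferred from the assertions in the test above
time_by_3 = chunk_tuple(5, 3)       # preferred_chunks={"time": 3}
time_by_2 = chunk_tuple(5, 2)       # preferred_chunks={"time": 2}
baseline_by_2 = chunk_tuple(6, 2)   # preferred_chunks={"baseline_id": 2}
frequency_by_2 = chunk_tuple(8, 2)  # preferred_chunks={"frequency": 2}
print(time_by_3, time_by_2, baseline_by_2, frequency_by_2)
```

Dimensions that receive no preferred chunk size appear in the assertions as a single block, for example `"polarization": (4,)`.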
2 changes: 1 addition & 1 deletion tests/test_read.py
@@ -54,7 +54,7 @@ def _excise_rows(data_dict):


@pytest.mark.filterwarnings(
r"ignore:.*?rows missing from the full \(time, baseline\) grid"
r"ignore:.*?rows missing from the full \(time, baseline_id\) grid"
)
@pytest.mark.parametrize(
"simmed_ms",
4 changes: 2 additions & 2 deletions xarray_ms/backend/msv2/array.py
@@ -57,7 +57,7 @@ def __init__(
self.shape = shape
self.dtype = np.dtype(dtype)

assert len(shape) >= 2, "(time, baseline) required"
assert len(shape) >= 2, "(time, baseline_ids) required"

def __getitem__(self, key) -> npt.NDArray:
return explicit_indexing_adapter(
@@ -67,7 +67,7 @@ def __getitem__(self, key) -> npt.NDArray:
def _getitem(self, key) -> npt.NDArray:
assert len(key) == len(self.shape)
expected_shape = tuple(slice_length(k, s) for k, s in zip(key, self.shape))
# Map the (time, baseline) coordinates onto row indices
# Map the (time, baseline_id) coordinates onto row indices
rows = self._structure_factory()[self._partition].row_map[key[:2]]
xkey = (rows.ravel(),) + key[2:]
row_shape = (rows.size,) + expected_shape[2:]