Skip to content

Commit

Permalink
Introduce a partition_chunks argument into the MSv2 open_datatree method
Browse files Browse the repository at this point in the history
  • Loading branch information
sjperkins committed Oct 15, 2024
1 parent f151e81 commit c6f9725
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
11 changes: 7 additions & 4 deletions tests/test_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def test_open_datatree(simmed_ms):

# Works with default dask scheduler
with ExitStack() as stack:
dt = open_datatree(simmed_ms, chunks=chunks)
dt = open_datatree(simmed_ms, partition_chunks=chunks)
for ds in dt.values():
del ds.attrs["creation_date"]
xt.assert_identical(dt, mem_dt)
Expand All @@ -165,7 +165,7 @@ def test_open_datatree(simmed_ms):
with ExitStack() as stack:
cluster = stack.enter_context(LocalCluster(processes=True, n_workers=4))
stack.enter_context(Client(cluster))
dt = open_datatree(simmed_ms, chunks=chunks)
dt = open_datatree(simmed_ms, partition_chunks=chunks)
for ds in dt.values():
del ds.attrs["creation_date"]
xt.assert_identical(dt, mem_dt)
Expand All @@ -186,7 +186,7 @@ def test_open_datatree_chunking(simmed_ms):
and partition-specific chunking"""
dt = open_datatree(
simmed_ms,
chunks={"time": 3, "frequency": 2},
partition_chunks={"time": 3, "frequency": 2},
)

for child in dt.children:
Expand All @@ -210,7 +210,10 @@ def test_open_datatree_chunking(simmed_ms):

dt = open_datatree(
simmed_ms,
chunks={"D=0": {"time": 2, "baseline": 2}, "D=1": {"time": 3, "frequency": 2}},
partition_chunks={
"D=0": {"time": 2, "baseline": 2},
"D=1": {"time": 3, "frequency": 2},
},
)

for child in dt.children:
Expand Down
17 changes: 14 additions & 3 deletions xarray_ms/backend/msv2/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,7 +298,7 @@ def open_datatree(
self,
filename_or_obj: str | os.PathLike[Any] | BufferedIOBase | AbstractDataStore,
*,
chunks: Dict[str, Any] | None = None,
partition_chunks: Dict[str, Any] | None = None,
drop_variables: str | Iterable[str] | None = None,
partition_columns: List[str] | None = None,
auto_corrs: bool = True,
Expand All @@ -311,7 +311,7 @@ def open_datatree(
Args:
filename_or_obj: The path to the MSv2 CASA Measurement Set file.
chunks: Chunk sizes along each dimension,
partition_chunks: Chunk sizes along each dimension,
e.g. :code:`{{"time": 10, "frequency": 16}}`.
Individual partitions can be chunked differently by
partially (or fully) specifying a partition key: e.g.
Expand All @@ -331,6 +331,11 @@ def open_datatree(
"D=0,F=1": {{"time": 20, "frequency": 32}},
}}
.. note:: This argument overrides the reserved ``chunks`` argument
used by xarray to control chunking in Datasets and DataTrees.
It should be used instead of ``chunks`` when different
chunking is desired for different partitions.
drop_variables: Variables to drop from the dataset.
partition_columns: The columns to use for partitioning the Measurement set.
Defaults to :code:`{DEFAULT_PARTITION_COLUMNS}`.
Expand All @@ -355,7 +360,13 @@ def open_datatree(

structure = structure_factory()
datasets = {}
pchunks = promote_chunks(structure, chunks)

if not partition_chunks:
partition_chunks = kwargs.pop("chunks", None)
elif "chunks" in kwargs:
warnings.warn("`partition_chunks` overriding `chunks`")

pchunks = promote_chunks(structure, partition_chunks)

for partition_key in structure:
ds = xarray.open_dataset(
Expand Down

0 comments on commit c6f9725

Please sign in to comment.