Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use partition keys as datatree node names #6

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions xarray_ms/backend/msv2/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,8 @@ def initialise_default_args(
epoch: str | None,
table_factory: TableFactory | None,
partition_columns: List[str] | None,
partition_key: PartitionKeyT | None,
structure_factory: MSv2StructureFactory | None,
) -> Tuple[str, TableFactory, List[str], PartitionKeyT, MSv2StructureFactory]:
) -> Tuple[str, TableFactory, List[str], MSv2StructureFactory]:
"""
Ensures consistency when initialising default arguments from multiple locations
"""
Expand All @@ -98,16 +97,7 @@ def initialise_default_args(
structure_factory = structure_factory or MSv2StructureFactory(
table_factory, partition_columns, auto_corrs=auto_corrs
)
structure = structure_factory()
if partition_key is None:
partition_key = next(iter(structure.keys()))
warnings.warn(
f"No partition_key was supplied. Selected first partition {partition_key}"
)
elif partition_key not in structure:
raise ValueError(f"{partition_key} not in {list(structure.keys())}")

return epoch, table_factory, partition_columns, partition_key, structure_factory
return epoch, table_factory, partition_columns, structure_factory


class MSv2Store(AbstractWritableDataStore):
Expand Down Expand Up @@ -164,19 +154,28 @@ def open(
if not isinstance(ms, str):
raise ValueError("Measurement Sets paths must be strings")

epoch, table_factory, partition_columns, partition_key, structure_factory = (
epoch, table_factory, partition_columns, structure_factory = (
initialise_default_args(
ms,
ninstances,
auto_corrs,
epoch,
None,
partition_columns,
partition_key,
structure_factory,
)
)

structure = structure_factory()

if partition_key is None:
partition_key = next(iter(structure.keys()))
warnings.warn(
f"No partition_key was supplied. Selected first partition {partition_key}"
)
elif partition_key not in structure:
raise ValueError(f"{partition_key} not in {list(structure.keys())}")

return cls(
table_factory,
structure_factory,
Expand Down Expand Up @@ -332,16 +331,16 @@ def open_datatree(
else:
raise ValueError("Measurement Set paths must be strings")

epoch, _, partition_columns, _, structure_factory = initialise_default_args(
ms, ninstances, auto_corrs, epoch, None, partition_columns, None, None
epoch, _, partition_columns, structure_factory = initialise_default_args(
ms, ninstances, auto_corrs, epoch, None, partition_columns, None
)

structure = structure_factory()
datasets = {}
chunks = kwargs.pop("chunks", None)
pchunks = promote_chunks(structure, chunks)

for i, partition_key in enumerate(structure):
for partition_key in structure:
ds = xarray.open_dataset(
ms,
drop_variables=drop_variables,
Expand All @@ -354,6 +353,8 @@ def open_datatree(
chunks=None if pchunks is None else pchunks[partition_key],
**kwargs,
)
datasets[str(i)] = ds

key = ",".join(f"{k}={v}" for k, v in sorted(partition_key))
datasets[key] = ds

return DataTree.from_dict(datasets)
Loading