Raise when reduced dims are chunked in map_blocks #1482

Merged: 2 commits, Sep 28, 2023
1 change: 1 addition & 0 deletions CHANGES.rst
@@ -16,6 +16,7 @@ Bug fixes
^^^^^^^^^
* Fixed an error in the `pytest` configuration that prevented copying of testing data to thread-safe caches of workers under certain conditions (this copying should always occur). (:pull:`1473`).
* Coincidentally, this also fixes an error that caused `pytest` to error out when invoked without an active internet connection. Running `pytest` without network access is now supported (requires cached testing data). (:issue:`1468`).
+* Calling a ``sdba.map_blocks``-wrapped function with data chunked along the reduced dimensions now raises an error. For example, this forbids chunking a trained dataset along its distribution dimensions. (:issue:`1481`, :pull:`1482`).

Breaking changes
^^^^^^^^^^^^^^^^
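To illustrate the changelog entry above, here is a minimal sketch of the new behavior. It is patterned on the test added below, not code from the PR; the function body and sample data are illustrative:

```python
import numpy as np
import pandas as pd
import xarray as xr

from xclim.sdba.base import map_blocks

# A wrapped function that declares it reduces the "lat" dimension.
@map_blocks(reduces=["lat"], data=[])
def func(ds, *, group):
    return ds.tas.mean("lat").rename("data").to_dataset()

tas = xr.DataArray(
    np.arange(366.0),
    dims=("time",),
    coords={"time": pd.date_range("2000-01-01", periods=366)},
    name="tas",
).expand_dims(lat=[1, 2, 3, 4])

# With "lat" split into several chunks, the call now raises, e.g.:
# ValueError: The dimension(s) over which we group, reduce or interpolate
# cannot be chunked ({'lat': (1, 1, 1, 1)}).
func(xr.Dataset({"tas": tas.chunk(lat=1)}), group="time")
```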
13 changes: 13 additions & 0 deletions tests/test_sdba/test_base.py
@@ -203,3 +203,16 @@ def func(ds, *, dim):
    ).load()
    assert set(data.data.dims) == {"dayofyear"}
    assert "leftover" in data


+def test_map_blocks_error(tas_series):
+    tas = tas_series(np.arange(366), start="2000-01-01")
+    tas = tas.expand_dims(lat=[1, 2, 3, 4]).chunk(lat=1)
+
+    # Test dim parsing
+    @map_blocks(reduces=["lat"], data=[])
+    def func(ds, *, group, lon=None):
+        return ds.tas.rename("data").to_dataset()
+
+    with pytest.raises(ValueError, match="cannot be chunked"):
+        func(xr.Dataset(dict(tas=tas)), group="time")
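For user code that hits this error, the remedy is to merge the offending dimension into a single chunk before calling the wrapped function; other dimensions may stay chunked. A sketch using the standard xarray API (not from the PR):

```python
# -1 means "one chunk spanning the whole dimension".
ds = ds.chunk({"lat": -1})
```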
52 changes: 31 additions & 21 deletions xclim/sdba/base.py
@@ -492,7 +492,7 @@ def duck_empty(dims, sizes, dtype="float64", chunks=None):
def _decode_cf_coords(ds):
    """Decode coords in-place."""
    crds = xr.decode_cf(ds.coords.to_dataset())
-    for crdname in ds.coords.keys():
+    for crdname in list(ds.coords.keys()):
        ds[crdname] = crds[crdname]
        # decode_cf introduces an encoding key for the dtype, which can confuse the netCDF writer
        dtype = ds[crdname].encoding.get("dtype")
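The `list(...)` wrap snapshots the coordinate names before the loop assigns back into `ds`, which can mutate the coords mapping mid-iteration. A minimal plain-Python sketch of the hazard it guards against (not xclim code):

```python
d = {"a": 1, "b": 2}
# Without list(), iterating the live .keys() view while the mapping changes
# size raises: RuntimeError: dictionary changed size during iteration
for k in list(d.keys()):  # snapshot the names first, as the fix above does
    d[k.upper()] = d.pop(k)
```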
@@ -557,26 +557,6 @@ def _map_blocks(ds, **kwargs):
        ) and group is None:
            raise ValueError("Missing required `group` argument.")

-        if uses_dask(ds):
-            # Use dask if any of the input is dask-backed.
-            chunks = (
-                dict(ds.chunks)
-                if isinstance(ds, xr.Dataset)
-                else dict(zip(ds.dims, ds.chunks))
-            )
-            if group is not None:
-                badchunks = {
-                    dim: chunks.get(dim)
-                    for dim in group.add_dims + [group.dim]
-                    if len(chunks.get(dim, [])) > 1
-                }
-                if badchunks:
-                    raise ValueError(
-                        f"The dimension(s) over which we group cannot be chunked ({badchunks})."
-                    )
-        else:
-            chunks = None

        # Make translation dict
        if group is not None:
            placeholders = {
@@ -602,6 +582,36 @@
f"Dimension {dim} is meant to be added by the computation but it is already on one of the inputs."
)

+        if uses_dask(ds):
+            # Use dask if any of the input is dask-backed.
+            chunks = (
+                dict(ds.chunks)
+                if isinstance(ds, xr.Dataset)
+                else dict(zip(ds.dims, ds.chunks))
+            )
+            badchunks = {}
+            if group is not None:
+                badchunks.update(
+                    {
+                        dim: chunks.get(dim)
+                        for dim in group.add_dims + [group.dim]
+                        if len(chunks.get(dim, [])) > 1
+                    }
+                )
+            badchunks.update(
+                {
+                    dim: chunks.get(dim)
+                    for dim in reduced_dims
+                    if len(chunks.get(dim, [])) > 1
+                }
+            )
+            if badchunks:
+                raise ValueError(
+                    f"The dimension(s) over which we group, reduce or interpolate cannot be chunked ({badchunks})."
+                )
+        else:
+            chunks = None

        # Dimensions untouched by the function.
        base_dims = list(set(ds.dims) - set(new_dims) - set(reduced_dims))
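The check moved below the dimension parsing so that `reduced_dims` is already defined when it runs. A standalone sketch of what it computes, with hypothetical values (`dict(ds.chunks)` maps each dimension to its tuple of chunk sizes):

```python
chunks = {"time": (366,), "lat": (1, 1, 1, 1)}  # e.g. dict(ds.chunks)
reduced_dims = ["lat"]

# A dimension is badly chunked when it is split into more than one chunk.
badchunks = {
    dim: chunks.get(dim)
    for dim in reduced_dims
    if len(chunks.get(dim, [])) > 1
}
assert badchunks == {"lat": (1, 1, 1, 1)}  # would trigger the ValueError
```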
