# Raise when reduced dims are chunked in map_blocks (#1482)
### Pull Request Checklist:
- [x] This PR addresses an already opened issue (for bug fixes /
features)
    - This PR fixes #1481
- [x] Tests for the changes have been added (for bug fixes / features)
- [ ] (If applicable) Documentation has been added / updated (for bug
fixes / features)
- [x] CHANGES.rst has been updated (with summary of main changes)
- [x] Link to issue (:issue:`number`) and pull request (:pull:`number`)
has been added

### What kind of change does this PR introduce?
If a `map_blocks`-wrapped function receives input chunked along the
dimensions marked as "reduced", it now raises an error.

This means one cannot chunk the training dataset (`QM.ds`, for example)
along the distribution ('quantiles') or group ('month', 'dayofyear')
dimensions.

Previously, no error was raised: the adjustment was silently performed
on each chunk separately, yielding incorrect results.
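
For illustration, a minimal sketch of the new behaviour using xclim's `sdba` API (the toy data, the training parameters and the direct `QDM.ds` reassignment are illustrative, not prescriptive):

```python
import numpy as np
import xarray as xr
from xclim import sdba

# Toy daily reference and historical series (xclim requires units).
time = xr.cftime_range("2000-01-01", periods=365, freq="D", calendar="noleap")
rng = np.random.default_rng(0)
ref = xr.DataArray(
    280 + 10 * rng.random(365), dims=("time",), coords={"time": time},
    attrs={"units": "K"}, name="tas",
)
hist = xr.DataArray(
    281 + 10 * rng.random(365), dims=("time",), coords={"time": time},
    attrs={"units": "K"}, name="tas",
)

# Training adds 'quantiles' (and 'month') dimensions to the trained dataset.
QDM = sdba.QuantileDeltaMapping.train(ref, hist, nquantiles=20, group="time.month")

# Chunking the trained dataset along a reduced dimension previously ran the
# adjustment chunk by chunk and silently returned wrong results.
# After this PR, the wrapped function raises instead:
QDM.ds = QDM.ds.chunk({"quantiles": 5})
QDM.adjust(hist)  # ValueError: ... cannot be chunked ...
```

With a trained dataset that is unchunked (or chunked only along untouched dimensions, e.g. spatial ones), the same call proceeds as before.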

### Does this PR introduce a breaking change?
No.

### Other information:
aulemahal authored Sep 28, 2023 · 2 parents 0a48bbc + c846ca6 · commit e9596e5
Showing 3 changed files with 45 additions and 21 deletions.
#### CHANGES.rst (1 addition, 0 deletions)

```diff
@@ -16,6 +16,7 @@ Bug fixes
 ^^^^^^^^^
 * Fixed an error in the `pytest` configuration that prevented copying of testing data to thread-safe caches of workers under certain conditions (this should always occur). (:pull:`1473`).
 * Coincidentally, this also fixes an error that caused `pytest` to error-out when invoked without an active internet connection. Running `pytest` without network access is now supported (requires cached testing data). (:issue:`1468`).
+* Calling a ``sdba.map_blocks``-wrapped function with data chunked along the reduced dimensions will raise an error. This forbids chunking the trained dataset along the distribution dimensions, for example. (:issue:`1481`, :pull:`1482`).
 
 Breaking changes
 ^^^^^^^^^^^^^^^^
```
#### tests/test_sdba/test_base.py (13 additions, 0 deletions)

```diff
@@ -203,3 +203,16 @@ def func(ds, *, dim):
     ).load()
     assert set(data.data.dims) == {"dayofyear"}
     assert "leftover" in data
+
+
+def test_map_blocks_error(tas_series):
+    tas = tas_series(np.arange(366), start="2000-01-01")
+    tas = tas.expand_dims(lat=[1, 2, 3, 4]).chunk(lat=1)
+
+    # 'lat' is declared as reduced, but the input is chunked along it.
+    @map_blocks(reduces=["lat"], data=[])
+    def func(ds, *, group, lon=None):
+        return ds.tas.rename("data").to_dataset()
+
+    with pytest.raises(ValueError, match="cannot be chunked"):
+        func(xr.Dataset(dict(tas=tas)), group="time")
```
#### xclim/sdba/base.py (31 additions, 21 deletions)

```diff
@@ -492,7 +492,7 @@ def duck_empty(dims, sizes, dtype="float64", chunks=None):
 def _decode_cf_coords(ds):
     """Decode coords in-place."""
     crds = xr.decode_cf(ds.coords.to_dataset())
-    for crdname in ds.coords.keys():
+    for crdname in list(ds.coords.keys()):
         ds[crdname] = crds[crdname]
         # decode_cf introduces an encoding key for the dtype, which can confuse the netCDF writer
         dtype = ds[crdname].encoding.get("dtype")
```
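
The one-line change above wraps the coordinate names in `list(...)` so they are snapshotted before the loop assigns back into `ds`. As a rough illustration of why mutating a mapping while iterating over its live key view is unsafe (a plain-dict sketch, not the xarray internals):

```python
d = {"a": 1, "b": 2}

# Iterating over the live view while inserting keys fails:
# for k in d.keys():
#     d[k.upper()] = d[k]  # RuntimeError: dictionary changed size during iteration

# Snapshotting the keys first makes the mutation safe:
for k in list(d.keys()):
    d[k.upper()] = d[k]

print(d)  # {'a': 1, 'b': 2, 'A': 1, 'B': 2}
```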
```diff
@@ -557,26 +557,6 @@ def _map_blocks(ds, **kwargs):
     ) and group is None:
         raise ValueError("Missing required `group` argument.")
 
-    if uses_dask(ds):
-        # Use dask if any of the input is dask-backed.
-        chunks = (
-            dict(ds.chunks)
-            if isinstance(ds, xr.Dataset)
-            else dict(zip(ds.dims, ds.chunks))
-        )
-        if group is not None:
-            badchunks = {
-                dim: chunks.get(dim)
-                for dim in group.add_dims + [group.dim]
-                if len(chunks.get(dim, [])) > 1
-            }
-            if badchunks:
-                raise ValueError(
-                    f"The dimension(s) over which we group cannot be chunked ({badchunks})."
-                )
-    else:
-        chunks = None
-
     # Make translation dict
     if group is not None:
         placeholders = {
@@ -602,6 +582,36 @@
                 f"Dimension {dim} is meant to be added by the computation but it is already on one of the inputs."
             )
 
+    if uses_dask(ds):
+        # Use dask if any of the input is dask-backed.
+        chunks = (
+            dict(ds.chunks)
+            if isinstance(ds, xr.Dataset)
+            else dict(zip(ds.dims, ds.chunks))
+        )
+        badchunks = {}
+        if group is not None:
+            badchunks.update(
+                {
+                    dim: chunks.get(dim)
+                    for dim in group.add_dims + [group.dim]
+                    if len(chunks.get(dim, [])) > 1
+                }
+            )
+        badchunks.update(
+            {
+                dim: chunks.get(dim)
+                for dim in reduced_dims
+                if len(chunks.get(dim, [])) > 1
+            }
+        )
+        if badchunks:
+            raise ValueError(
+                f"The dimension(s) over which we group, reduce or interpolate cannot be chunked ({badchunks})."
+            )
+    else:
+        chunks = None
+
     # Dimensions untouched by the function.
     base_dims = list(set(ds.dims) - set(new_dims) - set(reduced_dims))
```
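
For reference, the relocated validation boils down to the following standalone check (the helper name `check_chunks` is hypothetical; `reduced_dims` is just a list of dimension names here):

```python
import numpy as np
import xarray as xr


def check_chunks(ds: xr.Dataset, reduced_dims: list[str]) -> None:
    """Raise if a reduced dimension is split across more than one dask chunk."""
    chunks = dict(ds.chunks)  # e.g. {"lat": (1, 1, 1, 1), "time": (366,)}
    badchunks = {
        dim: chunks.get(dim)
        for dim in reduced_dims
        if len(chunks.get(dim, [])) > 1
    }
    if badchunks:
        raise ValueError(
            "The dimension(s) over which we group, reduce or interpolate "
            f"cannot be chunked ({badchunks})."
        )


ds = xr.Dataset({"tas": (("lat", "time"), np.zeros((4, 366)))}).chunk(lat=1)
check_chunks(ds, ["time"])  # fine: 'time' is a single chunk
check_chunks(ds, ["lat"])   # ValueError: 'lat' is split into four chunks
```

The real implementation accumulates `badchunks` for both the grouping dimensions and `reduced_dims` before raising a single error, as shown in the diff above.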
