diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 868e9acc7..01eab1e6a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -19,6 +19,7 @@ New features and enhancements ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ * New generic ``xclim.indices.generic.spell_mask`` that returns a mask of which days are part of a spell. Supports multivariate conditions and weights. Used in new generic index ``xclim.indices.generic.bivariate_spell_length_statistics`` that extends ``spell_length_statistics`` to two variables. (:pull:`1885`). * Indicator parameters can now be assigned a new name, different from the argument name in the compute function. (:pull:`1885`). +* New global option ``resample_map_blocks`` to wrap all ``resample().map()`` code inside a ``xr.map_blocks`` to lower the number of dask tasks. Uses utility ``xclim.indices.helpers.resample_map`` and requires ``flox`` to ensure the chunking allows such block-mapping. Defaults to False. (:pull:`1848`). Bug fixes ^^^^^^^^^ diff --git a/pyproject.toml b/pyproject.toml index f330c40d5..d42db0b49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -180,7 +180,7 @@ pep621_dev_dependency_groups = ["all", "dev", "docs"] "POT" = "ot" [tool.deptry.per_rule_ignores] -DEP001 = ["SBCK"] +DEP001 = ["SBCK", "flox"] DEP002 = ["bottleneck", "h5netcdf", "pyarrow"] DEP004 = ["matplotlib", "pooch", "pytest", "pytest_socket"] diff --git a/tests/test_indices.py b/tests/test_indices.py index d1da9c54c..f33ce3ef2 100644 --- a/tests/test_indices.py +++ b/tests/test_indices.py @@ -29,6 +29,11 @@ from xclim.core.options import set_options from xclim.core.units import convert_units_to, units +try: + import flox +except ImportError: + flox = None + K2C = 273.15 @@ -1428,11 +1433,14 @@ def test_1d(self, tasmax_series, thresh, window, op, expected): def test_resampling_order(self, tasmax_series, resample_before_rl, expected): a = np.zeros(365) a[5:35] = 31 - tx = tasmax_series(a + K2C) + tx = tasmax_series(a + K2C).chunk() - hsf = xci.hot_spell_frequency( - tx, 
resample_before_rl=resample_before_rl, freq="MS" - ) + with set_options( + resample_map_blocks=(resample_before_rl and (flox is not None)) + ): + hsf = xci.hot_spell_frequency( + tx, resample_before_rl=resample_before_rl, freq="MS" + ).load() assert hsf[1] == expected @@ -1708,10 +1716,13 @@ def test_run_start_at_0(self, pr_series): def test_resampling_order(self, pr_series, resample_before_rl, expected): a = np.zeros(365) + 10 a[5:35] = 0 - pr = pr_series(a) - out = xci.maximum_consecutive_dry_days( - pr, freq="ME", resample_before_rl=resample_before_rl - ) + pr = pr_series(a).chunk() + with set_options( + resample_map_blocks=(resample_before_rl and (flox is not None)) + ): + out = xci.maximum_consecutive_dry_days( + pr, freq="ME", resample_before_rl=resample_before_rl + ).load() assert out[0] == expected diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py index 97242c692..2a3fb5765 100644 --- a/xclim/core/indicator.py +++ b/xclim/core/indicator.py @@ -165,6 +165,7 @@ infer_kind_from_parameter, is_percentile_dataarray, load_module, + split_auxiliary_coordinates, ) # Indicators registry @@ -1446,13 +1447,12 @@ def _postprocess(self, outs, das, params): # Reduce by or and broadcast to ensure the same length in time # When indexing is used and there are no valid points in the last period, mask will not include it mask = reduce(np.logical_or, miss) - if ( - isinstance(mask, DataArray) - and "time" in mask.dims - and mask.time.size < outs[0].time.size - ): - mask = mask.reindex(time=outs[0].time, fill_value=True) - outs = [out.where(np.logical_not(mask)) for out in outs] + if isinstance(mask, DataArray): # mask might be a bool in some cases + if "time" in mask.dims and mask.time.size < outs[0].time.size: + mask = mask.reindex(time=outs[0].time, fill_value=True) + # Remove any aux coord to avoid any unwanted dask computation in the alignment within "where" + mask, _ = split_auxiliary_coordinates(mask) + outs = [out.where(~mask) for out in outs] return outs 
diff --git a/xclim/core/options.py b/xclim/core/options.py index 90896f5d8..47cab12c5 100644 --- a/xclim/core/options.py +++ b/xclim/core/options.py @@ -25,6 +25,7 @@ SDBA_ENCODE_CF = "sdba_encode_cf" KEEP_ATTRS = "keep_attrs" AS_DATASET = "as_dataset" +MAP_BLOCKS = "resample_map_blocks" MISSING_METHODS: dict[str, Callable] = {} @@ -39,6 +40,7 @@ SDBA_ENCODE_CF: False, KEEP_ATTRS: "xarray", AS_DATASET: False, + MAP_BLOCKS: False, } _LOUDNESS_OPTIONS = frozenset(["log", "warn", "raise"]) @@ -71,6 +73,7 @@ def _valid_missing_options(mopts): SDBA_ENCODE_CF: lambda opt: isinstance(opt, bool), KEEP_ATTRS: _KEEP_ATTRS_OPTIONS.__contains__, AS_DATASET: lambda opt: isinstance(opt, bool), + MAP_BLOCKS: lambda opt: isinstance(opt, bool), } @@ -185,6 +188,9 @@ class set_options: Note that xarray's "default" is equivalent to False. Default: ``"xarray"``. as_dataset : bool If True, indicators output datasets. If False, they output DataArrays. Default :``False``. + resample_map_blocks : bool + If True, some indicators will wrap their resampling operations with `xr.map_blocks`, using :py:func:`xclim.indices.helpers.resample_map`. + This requires `flox` to be installed in order to ensure the chunking is appropriate. Examples -------- diff --git a/xclim/core/utils.py b/xclim/core/utils.py index 5dff76c31..8b2f1991b 100644 --- a/xclim/core/utils.py +++ b/xclim/core/utils.py @@ -758,3 +758,43 @@ def _chunk_like(*inputs: xr.DataArray | xr.Dataset, chunks: dict[str, int] | Non da.chunk(**{d: c for d, c in chunks.items() if d in da.dims}) ) return tuple(outputs) + + +def split_auxiliary_coordinates( + obj: xr.DataArray | xr.Dataset, +) -> tuple[xr.DataArray | xr.Dataset, xr.Dataset]: + """Split auxiliary coords from the dataset. + + An auxiliary coordinate is a coordinate variable that does not define a dimension and thus is not necessarily needed for dataset alignment. + Any coordinate that has a name different than its dimension(s) is flagged as auxiliary. 
All scalar coordinates are flagged as auxiliary. + + Parameters + ---------- + obj : DataArray or Dataset + Xarray object + + Returns + ------- + clean_obj : + Same as `obj` but without any auxiliary coordinate. + aux_coords : Dataset + The auxiliary coordinates as a dataset. Might be empty. + + Note + ---- + This is useful to circumvent xarray's alignment checks that will sometimes look at the auxiliary coordinate's data, which can trigger + unwanted dask computations. + + The auxiliary coordinates can be merged back with the dataset with + :py:meth:`xarray.Dataset.assign_coords` or :py:meth:`xarray.DataArray.assign_coords`. + + >>> clean, aux = split_auxiliary_coordinates(ds) + >>> merged = clean.assign_coords(aux.coords) + >>> merged.identical(ds) # True + """ + aux_crd_names = [ + nm for nm, crd in obj.coords.items() if len(crd.dims) != 1 or crd.dims[0] != nm + ] + aux_crd_ds = obj.coords.to_dataset()[aux_crd_names] + clean_obj = obj.drop_vars(aux_crd_names) + return clean_obj, aux_crd_ds diff --git a/xclim/indices/_threshold.py b/xclim/indices/_threshold.py index 2e7d9ea8d..6ddfef709 100644 --- a/xclim/indices/_threshold.py +++ b/xclim/indices/_threshold.py @@ -29,6 +29,7 @@ spell_length_statistics, threshold_count, ) +from xclim.indices.helpers import resample_map # Frequencies : YS: year start, QS-DEC: seasons starting in december, MS: month start # See http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases @@ -1491,12 +1492,17 @@ def last_spring_frost( thresh = convert_units_to(thresh, tasmin) cond = compare(tasmin, op, thresh, constrain=("<", "<=")) - out = cond.resample(time=freq).map( + out = resample_map( + cond, + "time", + freq, rl.last_run_before_date, - window=window, - date=before_date, - dim="time", - coord="dayofyear", + map_kwargs=dict( + window=window, + date=before_date, + dim="time", + coord="dayofyear", + ), ) out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(tasmin)) return out @@ -1662,11 +1668,12 
@@ def first_snowfall( thresh = convert_units_to(thresh, prsn, context="hydro") cond = prsn >= thresh - out = cond.resample(time=freq).map( + out = resample_map( + cond, + "time", + freq, rl.first_run, - window=1, - dim="time", - coord="dayofyear", + map_kwargs=dict(window=1, dim="time", coord="dayofyear"), ) out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(prsn)) return out @@ -1717,11 +1724,12 @@ def last_snowfall( thresh = convert_units_to(thresh, prsn, context="hydro") cond = prsn >= thresh - out = cond.resample(time=freq).map( + out = resample_map( + cond, + "time", + freq, rl.last_run, - window=1, - dim="time", - coord="dayofyear", + map_kwargs=dict(window=1, dim="time", coord="dayofyear"), ) out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(prsn)) return out @@ -3097,7 +3105,7 @@ def _exceedance_date(grp): never_reached_val = never_reached return xarray.where((cumsum <= sum_thresh).all("time"), never_reached_val, out) - dded = c.clip(0).resample(time=freq).map(_exceedance_date) + dded = resample_map(c.clip(0), "time", freq, _exceedance_date) dded = dded.assign_attrs( units="", is_dayofyear=np.int32(1), calendar=get_calendar(tas) ) diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py index bc3d367de..89ceaf18d 100644 --- a/xclim/indices/generic.py +++ b/xclim/indices/generic.py @@ -27,6 +27,7 @@ to_agg_units, ) from xclim.indices import run_length as rl +from xclim.indices.helpers import resample_map __all__ = [ "aggregate_between_dates", @@ -90,14 +91,15 @@ def select_resample_op( The maximum value for each period. 
""" da = select_time(da, **indexer) - r = da.resample(time=freq) if isinstance(op, str): op = _xclim_ops.get(op, op) if isinstance(op, str): - out = getattr(r, op.replace("integral", "sum"))(dim="time", keep_attrs=True) + out = getattr(da.resample(time=freq), op.replace("integral", "sum"))( + dim="time", keep_attrs=True + ) else: with xr.set_options(keep_attrs=True): - out = r.map(op) + out = resample_map(da, "time", freq, op) op = op.__name__ if out_units is not None: return out.assign_attrs(units=out_units) @@ -734,7 +736,7 @@ def season( map_kwargs = {"window": window, "mid_date": mid_date} if stat in ["start", "end"]: map_kwargs["coord"] = "dayofyear" - out = cond.resample(time=freq).map(FUNC[stat], **map_kwargs) + out = resample_map(cond, "time", freq, FUNC[stat], map_kwargs=map_kwargs) if stat == "length": return to_agg_units(out, data, "count") # else, a date @@ -895,11 +897,12 @@ def first_occurrence( cond = compare(data, op, threshold, constrain) - out = cond.resample(time=freq).map( + out = resample_map( + cond, + "time", + freq, rl.first_run, - window=1, - dim="time", - coord="dayofyear", + map_kwargs=dict(window=1, dim="time", coord="dayofyear"), ) out.attrs["units"] = "" return out @@ -940,11 +943,12 @@ def last_occurrence( cond = compare(data, op, threshold, constrain) - out = cond.resample(time=freq).map( + out = resample_map( + cond, + "time", + freq, rl.last_run, - window=1, - dim="time", - coord="dayofyear", + map_kwargs=dict(window=1, dim="time", coord="dayofyear"), ) out.attrs["units"] = "" return out @@ -985,11 +989,12 @@ def spell_length( cond = compare(data, op, threshold) - out = cond.resample(time=freq).map( + out = resample_map( + cond, + "time", + freq, rl.rle_statistics, - reducer=reducer, - window=1, - dim="time", + map_kwargs=dict(reducer=reducer, window=1, dim="time"), ) return to_agg_units(out, data, "count") @@ -1329,12 +1334,12 @@ def first_day_threshold_reached( cond = compare(data, op, threshold, constrain=constrain) - out: 
xr.DataArray = cond.resample(time=freq).map( + out: xr.DataArray = resample_map( + cond, + "time", + freq, rl.first_run_after_date, - window=window, - date=after_date, - dim="time", - coord="dayofyear", + map_kwargs=dict(window=window, date=after_date, dim="time", coord="dayofyear"), ) out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(data)) return out diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py index 2a172212c..cc1bc1471 100644 --- a/xclim/indices/helpers.py +++ b/xclim/indices/helpers.py @@ -2,14 +2,15 @@ Indices Helper Functions Submodule ================================== -Functions that encapsulate some geophysical logic but could be shared by many indices. +Functions that encapsulate logic and can be shared by many indices, +but are not particularly index-like themselves (those should go in the :py:mod:`xclim.indices.generic` module). """ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Callable, Mapping from inspect import stack -from typing import Any, cast +from typing import Any, Literal, cast import cf_xarray # noqa: F401, pylint: disable=unused-import import cftime @@ -28,8 +29,9 @@ from xclim.core import Quantified from xclim.core.calendar import ensure_cftime_array, get_calendar +from xclim.core.options import MAP_BLOCKS, OPTIONS from xclim.core.units import convert_units_to -from xclim.core.utils import _chunk_like +from xclim.core.utils import _chunk_like, uses_dask def _wrap_radians(da): @@ -559,3 +561,83 @@ def _gather_lon(da: xr.DataArray) -> xr.DataArray: "Try passing it explicitly (`lon=ds.lon`)." 
) raise ValueError(msg) from err + + +def resample_map( + obj: xr.DataArray | xr.Dataset, + dim: str, + freq: str, + func: Callable, + map_blocks: bool | Literal["from_context"] = "from_context", + resample_kwargs: dict | None = None, + map_kwargs: dict | None = None, +) -> xr.DataArray | xr.Dataset: + r""" + Wraps xarray's resample(...).map() with a :py:func:`xarray.map_blocks`, ensuring the chunking is appropriate using flox. + + Parameters + ---------- + obj : DataArray or Dataset + The xarray object to resample. + dim : str + Dimension over which to resample. + freq : str + Resampling frequency along `dim`. + func : callable + Function to map on each resampled group. + map_blocks : bool or "from_context" + If True, the resample().map() call is wrapped inside a `map_blocks`. + If False, this does not do anything special. + If "from_context", xclim's "resample_map_blocks" option is used. + If the object is not using dask, this is set to False. + resample_kwargs : dict, optional + Other arguments to pass to `obj.resample()`. + map_kwargs : dict, optional + Arguments to pass to `map`. + + Returns + ------- + Resampled object. + """ + resample_kwargs = resample_kwargs or {} + map_kwargs = map_kwargs or {} + if map_blocks == "from_context": + map_blocks = OPTIONS[MAP_BLOCKS] + + if not uses_dask(obj) or not map_blocks: + return obj.resample({dim: freq}, **resample_kwargs).map(func, **map_kwargs) + + try: + from flox.xarray import rechunk_for_blockwise + except ImportError as err: + msg = f"Using {MAP_BLOCKS}=True requires flox." 
+ raise ValueError(msg) from err + + # Make labels, a unique integer for each resample group + labels = xr.full_like(obj[dim], -1, dtype=np.int32) + for lbl, group_slice in enumerate(obj[dim].resample({dim: freq}).groups.values()): + labels[group_slice] = lbl + + obj_rechunked = rechunk_for_blockwise(obj, dim, labels) + + def _resample_map(obj_chnk, dm, frq, rs_kws, fun, mp_kws): + return obj_chnk.resample({dm: frq}, **rs_kws).map(fun, **mp_kws) + + # Template. We are hoping that this takes a negligible time as it is never loaded. + template = obj_rechunked.resample(**{dim: freq}, **resample_kwargs).first() + + # New chunks along time : infer the number of elements resulting from the resampling of each chunk + if isinstance(obj_rechunked, xr.Dataset): + chunksizes = obj_rechunked.chunks[dim] + else: + chunksizes = obj_rechunked.chunks[obj_rechunked.get_axis_num(dim)] + new_chunks = [] + i = 0 + for chunksize in chunksizes: + new_chunks.append(len(np.unique(labels[i : i + chunksize]))) + i += chunksize + template = template.chunk({dim: tuple(new_chunks)}) + + return obj_rechunked.map_blocks( + _resample_map, (dim, freq, resample_kwargs, func, map_kwargs), template=template + ) diff --git a/xclim/indices/run_length.py b/xclim/indices/run_length.py index 3332343b1..b4d520a77 100644 --- a/xclim/indices/run_length.py +++ b/xclim/indices/run_length.py @@ -18,7 +18,8 @@ from xclim.core import DateStr, DayOfYearStr from xclim.core.options import OPTIONS, RUN_LENGTH_UFUNC -from xclim.core.utils import uses_dask +from xclim.core.utils import split_auxiliary_coordinates, uses_dask +from xclim.indices.helpers import resample_map npts_opt = 9000 """ @@ -111,8 +112,12 @@ def resample_and_rl( Output of compute resampled according to frequency {freq}. 
""" if resample_before_rl: - out = da.resample({dim: freq}).map( - compute, args=args, freq=None, dim=dim, **kwargs + out = resample_map( + da, + dim, + freq, + compute, + map_kwargs=dict(args=args, freq=None, dim=dim, **kwargs), ) else: out = compute(da, *args, dim=dim, freq=freq, **kwargs) @@ -254,8 +259,7 @@ def get_rl_stat(d): if freq is None: rl_stat = get_rl_stat(d) else: - rl_stat = d.resample({dim: freq}).map(get_rl_stat) - + rl_stat = resample_map(d, dim, freq, get_rl_stat) return rl_stat @@ -456,7 +460,6 @@ def coord_transform(out, da): crd = da[dim] if isinstance(coord, str): crd = getattr(crd.dt, coord) - out = lazy_indexing(crd, out) if dim in out.coords: @@ -481,7 +484,9 @@ def find_boundary_run(runs, position): da = da.fillna(0) # We expect a boolean array, but there could be NaNs nonetheless if window == 1: if freq is not None: - out = da.resample({dim: freq}).map(find_boundary_run, position=position) + out = resample_map( + da, dim, freq, find_boundary_run, map_kwargs=dict(position=position) + ) else: out = find_boundary_run(da, position) @@ -500,7 +505,9 @@ def find_boundary_run(runs, position): d = xr.where(d >= window, 1, 0) # for "first" run, return "first" element in the run (and conversely for "last" run) if freq is not None: - out = d.resample({dim: freq}).map(find_boundary_run, position=position) + out = resample_map( + d, dim, freq, find_boundary_run, map_kwargs=dict(position=position) + ) else: out = find_boundary_run(d, position) @@ -703,7 +710,7 @@ def get_out(rls): return out if freq is not None: - out = rls.resample({dim: freq}).map(get_out) + out = resample_map(rls, dim, freq, get_out) else: out = get_out(rls) @@ -867,8 +874,9 @@ def season( window: int, mid_date: DayOfYearStr | None = None, dim: str = "time", + stat: str | None = None, coord: str | bool | None = False, -) -> xr.Dataset: +) -> xr.Dataset | xr.DataArray: """Calculate the bounds of a season along a dimension. 
A "season" is a run of True values that may include breaks under a given length (`window`). @@ -1494,13 +1502,10 @@ def _index_from_1d_array(indices, array): # Renaming with no name to fix bug in xr 2024.01.0 tmpname = get_temp_dimname(da.dims, "temp") da2 = xr.DataArray(da.data, dims=(tmpname,), name=None) + # Map blocks chunks aux coords. Remove them to avoid the alignment check load in `where` + index, auxcrd = split_auxiliary_coordinates(index) # for each chunk of index, take corresponding values from da out = index.map_blocks(_index_from_1d_array, args=(da2,)).rename(da.name) - # Map blocks chunks aux coords. Replace them by non-chunked from the original array. - # This avoids unwanted loading of the aux coord in a resample.map, for example - for name, crd in out.coords.items(): - if uses_dask(crd) and name in index.coords and index[name].size == crd.size: - out = out.assign_coords(**{name: index[name]}) # mask where index was NaN. Drop any auxiliary coord, they are already on `out`. # Chunked aux coord would have the same name on both sides and xarray will want to check if they are equal, which means loading them # making lazy_indexing not lazy. same issue as above @@ -1509,6 +1514,7 @@ def _index_from_1d_array(indices, array): [crd for crd in invalid.coords if crd not in invalid.dims] ) ) + out = out.assign_coords(auxcrd.coords) if idx_ndim == 0: # 0-D case, drop useless coords and dummy dim out = out.drop_vars(da.dims[0], errors="ignore").squeeze()