From 30f28fec3a350858e9b52b5509d07beb7fdbe3b1 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Fri, 19 Jul 2024 13:31:15 -0400
Subject: [PATCH 01/25] Resample map helper

---
 xclim/core/options.py       |   6 ++
 xclim/indices/_threshold.py | 107 +++++++++++++++++++-----------------
 xclim/indices/generic.py    |  47 +++++++++-------
 xclim/indices/helpers.py    |  87 ++++++++++++++++++++++++++++-
 xclim/indices/run_length.py |  28 ++++++----
 5 files changed, 192 insertions(+), 83 deletions(-)

diff --git a/xclim/core/options.py b/xclim/core/options.py
index e4f78a255..5a796af64 100644
--- a/xclim/core/options.py
+++ b/xclim/core/options.py
@@ -25,6 +25,7 @@
 SDBA_ENCODE_CF = "sdba_encode_cf"
 KEEP_ATTRS = "keep_attrs"
 AS_DATASET = "as_dataset"
+MAP_BLOCKS = "resample_map_blocks"

 MISSING_METHODS: dict[str, Callable] = {}
@@ -39,6 +40,7 @@
     SDBA_ENCODE_CF: False,
     KEEP_ATTRS: "xarray",
     AS_DATASET: False,
+    MAP_BLOCKS: False,
 }

 _LOUDNESS_OPTIONS = frozenset(["log", "warn", "raise"])
@@ -71,6 +73,7 @@ def _valid_missing_options(mopts):
     SDBA_ENCODE_CF: lambda opt: isinstance(opt, bool),
     KEEP_ATTRS: _KEEP_ATTRS_OPTIONS.__contains__,
     AS_DATASET: lambda opt: isinstance(opt, bool),
+    MAP_BLOCKS: lambda opt: isinstance(opt, bool),
 }


@@ -185,6 +188,9 @@ class set_options:
         Note that xarray's "default" is equivalent to False. Default: ``"xarray"``.
     as_dataset : bool
         If True, indicators output datasets. If False, they output DataArrays. Default: ``False``.
+    resample_map_blocks : bool
+        If True, some indicators will wrap their resampling operations with `xr.map_blocks`, using :py:func:`xclim.indices.helpers.resample_map`.
+        This requires `flox` to be installed in order to ensure the chunking is appropriate.

     Examples
     --------

diff --git a/xclim/indices/_threshold.py b/xclim/indices/_threshold.py
index 1c25ae63e..d0b33cb66 100644
--- a/xclim/indices/_threshold.py
+++ b/xclim/indices/_threshold.py
@@ -30,6 +30,7 @@
     spell_length_statistics,
     threshold_count,
 )
+from .helpers import resample_map

 # Frequencies : YS: year start, QS-DEC: seasons starting in december, MS: month start
 # See http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases
@@ -386,10 +387,12 @@ def snd_season_end(
     thresh = convert_units_to(thresh, snd)
     cond = snd >= thresh

-    resampled = (
-        cond.resample(time=freq)
-        .map(rl.season, window=window, dim="time", coord="dayofyear")
-        .end
+    resampled = resample_map(
+        cond,
+        "time",
+        freq,
+        rl.season,
+        map_kwargs=dict(window=window, dim="time", stat="end", coord="dayofyear"),
     )
     resampled = resampled.assign_attrs(
         units="", is_dayofyear=np.int32(1), calendar=get_calendar(snd)
@@ -435,10 +438,12 @@ def snw_season_end(
     thresh = convert_units_to(thresh, snw)
     cond = snw >= thresh

-    resampled = (
-        cond.resample(time=freq)
-        .map(rl.season, window=window, dim="time", coord="dayofyear")
-        .end
+    resampled = resample_map(
+        cond,
+        "time",
+        freq,
+        rl.season,
+        map_kwargs=dict(window=window, dim="time", stat="end", coord="dayofyear"),
     )
     resampled.attrs.update(
         units="", is_dayofyear=np.int32(1), calendar=get_calendar(snw)
@@ -484,15 +489,12 @@ def snd_season_start(
     thresh = convert_units_to(thresh, snd)
     cond = snd >= thresh

-    resampled = (
-        cond.resample(time=freq)
-        .map(
-            rl.season,
-            window=window,
-            dim="time",
-            coord="dayofyear",
-        )
-        .start
+    resampled = resample_map(
+        cond,
+        "time",
+        freq,
+        rl.season,
+        map_kwargs=dict(window=window, dim="time", stat="start", coord="dayofyear"),
     )
     resampled.attrs.update(
         units="", is_dayofyear=np.int32(1), calendar=get_calendar(snd)
@@ -539,15 +541,12 @@
def snw_season_start(
     thresh = convert_units_to(thresh, snw)
     cond = snw >= thresh

-    resampled = (
-        cond.resample(time=freq)
-        .map(
-            rl.season,
-            window=window,
-            dim="time",
-            coord="dayofyear",
-        )
-        .start
+    resampled = resample_map(
+        cond,
+        "time",
+        freq,
+        rl.season,
+        map_kwargs=dict(window=window, dim="time", stat="start", coord="dayofyear"),
     )
     resampled.attrs.update(
         units="", is_dayofyear=np.int32(1), calendar=get_calendar(snw)
@@ -592,11 +591,12 @@ def snd_season_length(
     thresh = convert_units_to(thresh, snd)
     cond = snd >= thresh
-
-    snd_sl = (
-        cond.resample(time=freq)
-        .map(rl.season, window=window, dim="time", coord="dayofyear")
-        .length
+    snd_sl = resample_map(
+        cond,
+        "time",
+        freq,
+        rl.season,
+        map_kwargs=dict(window=window, dim="time", stat="length", coord="dayofyear"),
     )
     snd_sl = to_agg_units(snd_sl.where(~valid), snd, "count")
     return snd_sl
@@ -639,10 +639,12 @@ def snw_season_length(
     thresh = convert_units_to(thresh, snw)
     cond = snw >= thresh

-    snw_sl = (
-        cond.resample(time=freq)
-        .map(rl.season, window=window, dim="time", coord="dayofyear")
-        .length
+    snw_sl = resample_map(
+        cond,
+        "time",
+        freq,
+        rl.season,
+        map_kwargs=dict(window=window, dim="time", stat="length", coord="dayofyear"),
     )
     snw_sl = to_agg_units(snw_sl.where(~valid), snw, "count")
     return snw_sl
@@ -1558,12 +1560,17 @@ def last_spring_frost(
     thresh = convert_units_to(thresh, tasmin)
     cond = compare(tasmin, op, thresh, constrain=("<", "<="))

-    out = cond.resample(time=freq).map(
+    out = resample_map(
+        cond,
+        "time",
+        freq,
         rl.last_run_before_date,
-        window=window,
-        date=before_date,
-        dim="time",
-        coord="dayofyear",
+        map_kwargs=dict(
+            window=window,
+            date=before_date,
+            dim="time",
+            coord="dayofyear",
+        ),
     )
     out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(tasmin))
     return out
@@ -1729,11 +1736,12 @@ def first_snowfall(
     thresh = convert_units_to(thresh, prsn, context="hydro")
     cond = prsn >= thresh

-    out = cond.resample(time=freq).map(
+    out = resample_map(
+        cond,
+        "time",
+        freq,
         rl.first_run,
-        window=1,
-        dim="time",
-        coord="dayofyear",
+        map_kwargs=dict(window=1, dim="time", coord="dayofyear"),
     )
     out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(prsn))
     return out
@@ -1784,11 +1792,12 @@ def last_snowfall(
     thresh = convert_units_to(thresh, prsn, context="hydro")
     cond = prsn >= thresh

-    out = cond.resample(time=freq).map(
+    out = resample_map(
+        cond,
+        "time",
+        freq,
         rl.last_run,
-        window=1,
-        dim="time",
-        coord="dayofyear",
+        map_kwargs=dict(window=1, dim="time", coord="dayofyear"),
     )
     out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(prsn))
     return out
@@ -3163,7 +3172,7 @@ def _exceedance_date(grp):
         )
         return xarray.where((cumsum <= sum_thresh).all("time"), never_reached_val, out)

-    dded = c.clip(0).resample(time=freq).map(_exceedance_date)
+    dded = resample_map(c.clip(0), "time", freq, _exceedance_date)
     dded = dded.assign_attrs(
         units="", is_dayofyear=np.int32(1), calendar=get_calendar(tas)
     )

diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py
index 970e24297..1bc79d1f1 100644
--- a/xclim/indices/generic.py
+++ b/xclim/indices/generic.py
@@ -29,6 +29,7 @@
 from xclim.core.utils import DayOfYearStr, Quantified, Quantity

 from . import run_length as rl
+from .helpers import resample_map

 __all__ = [
     "aggregate_between_dates",
@@ -88,14 +89,15 @@ def select_resample_op(
         The maximum value for each period.
""" da = select_time(da, **indexer) - r = da.resample(time=freq) if op in _xclim_ops: op = _xclim_ops[op] if isinstance(op, str): - out = getattr(r, op.replace("integral", "sum"))(dim="time", keep_attrs=True) + out = getattr(da.resample(time=freq), op.replace("integral", "sum"))( + dim="time", keep_attrs=True + ) else: with xr.set_options(keep_attrs=True): - out = r.map(op) + out = resample_map(da, "time", freq, op) op = op.__name__ if out_units is not None: return out.assign_attrs(units=out_units) @@ -544,7 +546,7 @@ def season( map_kwargs = dict(window=window, date=mid_date) if stat in ["start", "end"]: map_kwargs["coord"] = "dayofyear" - out = cond.resample(time=freq).map(FUNC[stat], **map_kwargs) + out = resample_map(cond, "time", freq, FUNC[stat], map_kwargs=map_kwargs) if stat == "length": return to_agg_units(out, data, "count") # else, a date @@ -705,11 +707,12 @@ def first_occurrence( cond = compare(data, op, threshold, constrain) - out = cond.resample(time=freq).map( + out = resample_map( + cond, + "time", + freq, rl.first_run, - window=1, - dim="time", - coord="dayofyear", + map_kwargs=dict(window=1, dim="time", coord="dayofyear"), ) out.attrs["units"] = "" return out @@ -750,11 +753,12 @@ def last_occurrence( cond = compare(data, op, threshold, constrain) - out = cond.resample(time=freq).map( + out = resample_map( + cond, + "time", + freq, rl.last_run, - window=1, - dim="time", - coord="dayofyear", + map_kwargs=dict(window=1, dim="time", coord="dayofyear"), ) out.attrs["units"] = "" return out @@ -795,11 +799,12 @@ def spell_length( cond = compare(data, op, threshold) - out = cond.resample(time=freq).map( + out = resample_map( + cond, + "time", + freq, rl.rle_statistics, - reducer=reducer, - window=1, - dim="time", + map_kwargs=dict(reducer=reducer, window=1, dim="time"), ) return to_agg_units(out, data, "count") @@ -1139,12 +1144,12 @@ def first_day_threshold_reached( cond = compare(data, op, threshold, constrain=constrain) - out: xarray.DataArray = cond.resample(time=freq).map( + out = resample_map( + cond, + "time", + freq, rl.first_run_after_date, - window=window, - date=after_date, - dim="time", - coord="dayofyear", + map_kwargs=dict(window=window, date=after_date, dim="time", coord="dayofyear"), ) out.attrs.update(units="", is_dayofyear=np.int32(1), calendar=get_calendar(data)) return out diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py index 057c39305..66019ac07 100644 --- a/xclim/indices/helpers.py +++ b/xclim/indices/helpers.py @@ -2,14 +2,15 @@ Indices Helper Functions Submodule ================================== -Functions that encapsulate some geophysical logic but could be shared by many indices. +Functions that encapsulate logic but could be shared by many indices, +but are not index-like themselves (these should go in the :py:mod:`xclim.indices.generic` module). """ from __future__ import annotations from collections.abc import Mapping from inspect import stack -from typing import Any +from typing import Any, Callable, Literal import cf_xarray # noqa: F401, pylint: disable=unused-import import cftime @@ -21,8 +22,9 @@ ) from xclim.core.calendar import ensure_cftime_array, get_calendar +from xclim.core.options import MAP_BLOCKS, OPTIONS from xclim.core.units import convert_units_to -from xclim.core.utils import Quantified, _chunk_like +from xclim.core.utils import Quantified, _chunk_like, uses_dask def _wrap_radians(da): @@ -540,3 +542,82 @@ def _gather_lon(da: xr.DataArray) -> xr.DataArray: "Try passing it explicitly (`lon=ds.lon`)." 
         )
         raise ValueError(msg) from err
+
+
+def resample_map(
+    obj: xr.DataArray | xr.Dataset,
+    dim: str,
+    freq: str,
+    func: Callable,
+    map_blocks: bool | Literal["from_context"] = "from_context",
+    resample_kwargs: dict | None = None,
+    map_kwargs: dict | None = None,
+) -> xr.DataArray | xr.Dataset:
+    r"""
+    Wraps xarray's resample(...).map() with a :py:func:`xarray.map_blocks`, ensuring the chunking is appropriate using flox.
+
+    Parameters
+    ----------
+    obj: DataArray or Dataset
+        The xarray object to resample.
+    dim: str
+        Dimension over which to resample.
+    freq : str
+        Resampling frequency along `dim`.
+    func : callable
+        Function to map on each resampled group.
+    map_blocks: bool or "from_context"
+        If True, the resample().map() call is wrapped inside a `map_blocks`.
+        If False, this does not do anything special.
+        If "from_context", xclim's "resample_map_blocks" option is used.
+        If the object is not using dask, this is set to False.
+    resample_kwargs: dict, optional
+        Other arguments to pass to `obj.resample()`.
+    map_kwargs: dict, optional
+        Arguments to pass to `map`.
+
+    Returns
+    -------
+    Resampled object.
+    """
+    resample_kwargs = resample_kwargs or {}
+    map_kwargs = map_kwargs or {}
+    if map_blocks == "from_context":
+        map_blocks = OPTIONS[MAP_BLOCKS]
+
+    if not uses_dask(obj) or not map_blocks:
+        return obj.resample({dim: freq}, **resample_kwargs).map(func, **map_kwargs)
+
+    try:
+        from flox.xarray import rechunk_for_blockwise
+    except ImportError as err:
+        raise ValueError(f"Using {MAP_BLOCKS}=True requires flox.") from err
+
+    # Make labels, a unique integer for each resample group
+    labels = xr.full_like(obj[dim], -1, dtype=np.int32)
+    for lbl, group_slice in enumerate(obj[dim].resample({dim: freq}).groups.values()):
+        labels[group_slice] = lbl
+
+    obj_rechunked = rechunk_for_blockwise(obj, dim, labels)
+
+    def _resample_map(obj_chnk, dm, frq, rs_kws, fun, mp_kws):
+        return obj_chnk.resample({dm: frq}, **rs_kws).map(fun, **mp_kws)
+
+    # Template. We are hoping that this takes a negligible time as it is never loaded.
+    template = obj_rechunked.resample(**{dim: freq}, **resample_kwargs).first()

+    # New chunks along time : infer the number of elements resulting from the resampling of each chunk
+    if isinstance(obj_rechunked, xr.Dataset):
+        chunksizes = obj_rechunked.chunks[dim]
+    else:
+        chunksizes = obj_rechunked.chunks[obj_rechunked.get_axis_num(dim)]
+    new_chunks = []
+    i = 0
+    for chunksize in chunksizes:
+        new_chunks.append(len(np.unique(labels[i : i + chunksize])))
+        i += chunksize
+    template = template.chunk({dim: tuple(new_chunks)})
+
+    return obj_rechunked.map_blocks(
+        _resample_map, (dim, freq, resample_kwargs, func, map_kwargs), template=template
+    )

diff --git a/xclim/indices/run_length.py b/xclim/indices/run_length.py
index ba972695a..013783a71 100644
--- a/xclim/indices/run_length.py
+++ b/xclim/indices/run_length.py
@@ -18,6 +18,7 @@

 from xclim.core.options import OPTIONS, RUN_LENGTH_UFUNC
 from xclim.core.utils import DateStr, DayOfYearStr, uses_dask
+from xclim.indices.helpers import resample_map

 npts_opt = 9000
 """
@@ -110,8 +111,12 @@ def resample_and_rl(
         Output of compute resampled according to frequency {freq}.
""" if resample_before_rl: - out = da.resample({dim: freq}).map( - compute, args=args, freq=None, dim=dim, **kwargs + out = resample_map( + da, + dim, + freq, + compute, + map_kwargs=dict(args=args, freq=None, dim=dim, **kwargs), ) else: out = compute(da, *args, dim=dim, freq=freq, **kwargs) @@ -253,8 +258,7 @@ def get_rl_stat(d): if freq is None: rl_stat = get_rl_stat(d) else: - rl_stat = d.resample({dim: freq}).map(get_rl_stat) - + rl_stat = resample_map(d, dim, freq, get_rl_stat) return rl_stat @@ -480,7 +484,9 @@ def find_boundary_run(runs, position): da = da.fillna(0) # We expect a boolean array, but there could be NaNs nonetheless if window == 1: if freq is not None: - out = da.resample({dim: freq}).map(find_boundary_run, position=position) + out = resample_map( + da, dim, freq, find_boundary_run, map_kwargs=dict(position=position) + ) else: out = find_boundary_run(da, position) @@ -499,7 +505,9 @@ def find_boundary_run(runs, position): d = xr.where(d >= window, 1, 0) # for "first" run, return "first" element in the run (and conversely for "last" run) if freq is not None: - out = d.resample({dim: freq}).map(find_boundary_run, position=position) + out = resample_map( + d, dim, freq, find_boundary_run, map_kwargs=dict(position=position) + ) else: out = find_boundary_run(d, position) @@ -702,7 +710,7 @@ def get_out(rls): return out if freq is not None: - out = rls.resample({dim: freq}).map(get_out) + out = resample_map(rls, dim, freq, get_out) else: out = get_out(rls) @@ -862,8 +870,9 @@ def season( window: int, date: DayOfYearStr | None = None, dim: str = "time", + stat: str | None = None, coord: str | bool | None = False, -) -> xr.Dataset: +) -> xr.Dataset | xr.DataArray: """Calculate the bounds of a season along a dimension. A "season" is a run of True values that may include breaks under a given length (`window`). 
@@ -993,8 +1002,7 @@ def season_length(
     season_start
     season_end
     """
-    seas = season(da, window, date, dim, coord=False)
-    return seas.length
+    return season(da, window, date, dim, coord=False).length


 def run_end_after_date(

From 065b4f31fd7494eadf8b044b44a1ee151f3375c7 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Tue, 23 Jul 2024 14:49:25 -0400
Subject: [PATCH 02/25] New split_aux_coord func to remove aux coord and avoid
 dask comp on alignment

---
 xclim/core/indicator.py     |  3 +++
 xclim/core/utils.py         | 39 +++++++++++++++++++++++++++++++++++++
 xclim/indices/run_length.py | 14 +++++++------
 3 files changed, 50 insertions(+), 6 deletions(-)

diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py
index 74777dbe9..71ca81897 100644
--- a/xclim/core/indicator.py
+++ b/xclim/core/indicator.py
@@ -162,6 +162,7 @@
     is_percentile_dataarray,
     load_module,
     raise_warn_or_log,
+    split_auxiliary_coordinates,
 )

 # Indicators registry
@@ -1456,6 +1457,8 @@ def _postprocess(self, outs, das, params):
             and mask.time.size < outs[0].time.size
         ):
             mask = mask.reindex(time=outs[0].time, fill_value=True)
+        # Remove any aux coord to avoid any unwanted dask computation in the alignment within "where"
+        mask, _ = split_auxiliary_coordinates(mask)
         outs = [out.where(~mask) for out in outs]

     return outs

diff --git a/xclim/core/utils.py b/xclim/core/utils.py
index 25fc7b985..35f5ee1dd 100644
--- a/xclim/core/utils.py
+++ b/xclim/core/utils.py
@@ -862,3 +862,42 @@ def _chunk_like(*inputs: xr.DataArray | xr.Dataset, chunks: dict[str, int] | Non
         da.chunk(**{d: c for d, c in chunks.items() if d in da.dims})
     )
     return tuple(outputs)
+
+
+def split_auxiliary_coordinates(
+    obj: xr.DataArray | xr.Dataset,
+) -> tuple[xr.DataArray | xr.Dataset, xr.Dataset]:
+    """Split auxiliary coords from the dataset.
+
+    An auxiliary coordinate is a coordinate variable that does not define a dimension and thus is not necessarily needed for dataset alignment.
+    Any coordinate that has a name different from its dimension(s) is flagged as auxiliary. All scalar coordinates are flagged as auxiliary.
+
+    Parameters
+    ----------
+    obj : DataArray or Dataset
+        Xarray object.
+
+    Returns
+    -------
+    clean_obj : DataArray or Dataset
+        Same as `obj` but without any auxiliary coordinate.
+    aux_coords : Dataset
+        The auxiliary coordinates as a dataset. Might be empty.
+
+    Note
+    ----
+    This is useful to circumvent xarray's alignment checks that will sometimes look at the auxiliary coordinate's data, which can trigger
+    unwanted dask computations.
+
+    The auxiliary coordinates can be merged back with the dataset with :py:func:`xarray.merge`.
+ + >>> clean, aux = split_auxiliary_coordinates(ds) + >>> merged = xr.merge([clean, aux]) + >>> merged.identical(ds) # True + """ + aux_crd_names = [ + nm for nm, crd in obj.coords.items() if len(crd.dims) != 1 or crd.dims[0] != nm + ] + aux_crd_ds = obj.coords.to_dataset()[aux_crd_names] + clean_obj = obj.drop_vars(aux_crd_names) + return clean_obj, aux_crd_ds diff --git a/xclim/indices/run_length.py b/xclim/indices/run_length.py index 013783a71..2f7a12888 100644 --- a/xclim/indices/run_length.py +++ b/xclim/indices/run_length.py @@ -17,7 +17,12 @@ from xarray.core.utils import get_temp_dimname from xclim.core.options import OPTIONS, RUN_LENGTH_UFUNC -from xclim.core.utils import DateStr, DayOfYearStr, uses_dask +from xclim.core.utils import ( + DateStr, + DayOfYearStr, + split_auxiliary_coordinates, + uses_dask, +) from xclim.indices.helpers import resample_map npts_opt = 9000 @@ -1496,11 +1501,8 @@ def _index_from_1d_array(indices, array): da2 = xr.DataArray(da.data, dims=(tmpname,), name=None) # for each chunk of index, take corresponding values from da out = index.map_blocks(_index_from_1d_array, args=(da2,)).rename(da.name) - # Map blocks chunks aux coords. Replace them by non-chunked from the original array. - # This avoids unwanted loading of the aux coord in a resample.map, for example - for name, crd in out.coords.items(): - if uses_dask(crd) and name in index.coords and index[name].size == crd.size: - out = out.assign_coords(**{name: index[name]}) + # Map blocks chunks aux coords. Remove them to avoid the alignment check load in `where` + out, _ = split_auxiliary_coordinates(out) # mask where index was NaN. Drop any auxiliary coord, they are already on `out`. # Chunked aux coord would have the same name on both sides and xarray will want to check if they are equal, which means loading them # making lazy_indexing not lazy. 

From 0c8c797f192e823ef0ff4e323b9023dec0684465 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Tue, 23 Jul 2024 14:54:44 -0400
Subject: [PATCH 03/25] Ignore missing flox dep

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 549670960..1f129cb62 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -177,7 +177,7 @@
 pep621_dev_dependency_groups = ["all", "dev", "docs"]
 "pyyaml" = "yaml"

 [tool.deptry.per_rule_ignores]
-DEP001 = ["SBCK"]
+DEP001 = ["SBCK", "flox"]
 DEP002 = ["bottleneck", "pyarrow"]
 DEP004 = ["matplotlib", "pytest_socket"]

From 9c81ed87381e17c2033d67dfa0faff0a68d550e8 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Tue, 23 Jul 2024 15:08:31 -0400
Subject: [PATCH 04/25] Fix for bool mask

---
 xclim/core/indicator.py | 13 +++++--------
 1 file changed, 5 insertions(+), 8 deletions(-)

diff --git a/xclim/core/indicator.py b/xclim/core/indicator.py
index 71ca81897..ada2e7839 100644
--- a/xclim/core/indicator.py
+++ b/xclim/core/indicator.py
@@ -1451,14 +1451,11 @@ def _postprocess(self, outs, das, params):
         # Reduce by or and broadcast to ensure the same length in time
         # When indexing is used and there are no valid points in the last period, mask will not include it
         mask = reduce(np.logical_or, miss)
-        if (
-            isinstance(mask, DataArray)
-            and "time" in mask.dims
-            and mask.time.size < outs[0].time.size
-        ):
-            mask = mask.reindex(time=outs[0].time, fill_value=True)
-        # Remove any aux coord to avoid any unwanted dask computation in the alignment within "where"
-        mask, _ = split_auxiliary_coordinates(mask)
+        if isinstance(mask, DataArray):  # mask might be a bool in some cases
+            if "time" in mask.dims and mask.time.size < outs[0].time.size:
+                mask = mask.reindex(time=outs[0].time, fill_value=True)
+            # Remove any aux coord to avoid any unwanted dask computation in the alignment within "where"
+            mask, _ = split_auxiliary_coordinates(mask)
         outs = [out.where(~mask) for out in outs]

     return outs

From 764d67b7e9f626d277b39e50f0068566fabe6bba Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Wed, 24 Jul 2024 09:36:45 -0400
Subject: [PATCH 05/25] Fix aux coord mngmt in lazy indexing - fix doc split
 aux coord

---
 xclim/core/utils.py         | 5 +++--
 xclim/indices/run_length.py | 4 ++--
 2 files changed, 5 insertions(+), 4 deletions(-)

diff --git a/xclim/core/utils.py b/xclim/core/utils.py
index 35f5ee1dd..8a133b5c4 100644
--- a/xclim/core/utils.py
+++ b/xclim/core/utils.py
@@ -889,10 +889,11 @@
     This is useful to circumvent xarray's alignment checks that will sometimes look at the auxiliary coordinate's data, which can trigger
     unwanted dask computations.

-    The auxiliary coordinates can be merged back with the dataset with :py:func:`xarray.merge`.
+    The auxiliary coordinates can be merged back with the dataset with
+    :py:meth:`xarray.Dataset.assign_coords` or :py:meth:`xarray.DataArray.assign_coords`.
     >>> clean, aux = split_auxiliary_coordinates(ds)
-    >>> merged = xr.merge([clean, aux])
+    >>> merged = clean.assign_coords(aux.coords)
     >>> merged.identical(ds)  # True
     """
     aux_crd_names = [

diff --git a/xclim/indices/run_length.py b/xclim/indices/run_length.py
index 2f7a12888..99bddc10c 100644
--- a/xclim/indices/run_length.py
+++ b/xclim/indices/run_length.py
@@ -464,7 +464,6 @@ def coord_transform(out, da):
     crd = da[dim]
     if isinstance(coord, str):
         crd = getattr(crd.dt, coord)
-
     out = lazy_indexing(crd, out)

     if dim in out.coords:
@@ -1502,7 +1501,7 @@ def _index_from_1d_array(indices, array):
     # for each chunk of index, take corresponding values from da
     out = index.map_blocks(_index_from_1d_array, args=(da2,)).rename(da.name)
     # Map blocks chunks aux coords. Remove them to avoid the alignment check load in `where`
-    out, _ = split_auxiliary_coordinates(out)
+    out, auxcrd = split_auxiliary_coordinates(out)
     # mask where index was NaN. Drop any auxiliary coord, they are already on `out`.
     # Chunked aux coord would have the same name on both sides and xarray will want to check if they are equal, which means loading them
     # making lazy_indexing not lazy. same issue as above
            [crd for crd in invalid.coords if crd not in invalid.dims]
        )
    )
+    out = out.assign_coords(auxcrd.coords)
     if idx_ndim == 0:
         # 0-D case, drop useless coords and dummy dim
         out = out.drop_vars(da.dims[0], errors="ignore").squeeze()

From 3d4c457d0c6272c9918358d4919c2543cdea09c4 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Wed, 24 Jul 2024 10:01:41 -0400
Subject: [PATCH 06/25] lower pin of flit

---
 pyproject.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 1f129cb62..724840111 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,5 +1,5 @@
 [build-system]
-requires = ["flit_core >=3.9,<4"]
+requires = ["flit_core >=3.8,<4"]
 build-backend = "flit_core.buildapi"

 [project]

From 3adbf763295de1d05d5cc0501710a2be284c32ca Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Fri, 26 Jul 2024 13:27:55 -0400
Subject: [PATCH 07/25] fix a fix that didn't fix what needed to be fixed

---
 xclim/indices/run_length.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/xclim/indices/run_length.py b/xclim/indices/run_length.py
index 99bddc10c..a955f69c8 100644
--- a/xclim/indices/run_length.py
+++ b/xclim/indices/run_length.py
@@ -1498,10 +1498,10 @@ def _index_from_1d_array(indices, array):
     # Renaming with no name to fix bug in xr 2024.01.0
     tmpname = get_temp_dimname(da.dims, "temp")
     da2 = xr.DataArray(da.data, dims=(tmpname,), name=None)
+    # Map blocks chunks aux coords. Remove them to avoid the alignment check load in `where`
+    index, auxcrd = split_auxiliary_coordinates(index)
     # for each chunk of index, take corresponding values from da
     out = index.map_blocks(_index_from_1d_array, args=(da2,)).rename(da.name)
-    # Map blocks chunks aux coords. Remove them to avoid the alignment check load in `where`
-    out, auxcrd = split_auxiliary_coordinates(out)
     # mask where index was NaN. Drop any auxiliary coord, they are already on `out`.
     # Chunked aux coord would have the same name on both sides and xarray will want to check if they are equal, which means loading them
     # making lazy_indexing not lazy. same issue as above
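
With patches 01 through 07 in place, the new machinery can be exercised end to end. Here is a usage sketch of `resample_map` and the `resample_map_blocks` option (the data and chunk sizes below are illustrative assumptions; the helper and option names come from PATCH 01, and the flox requirement from PATCHES 01 and 03):

    import numpy as np
    import pandas as pd
    import xarray as xr
    from xclim import set_options
    from xclim.indices.helpers import resample_map

    time = pd.date_range("2000-01-01", periods=8 * 365, freq="D")
    da = xr.DataArray(
        np.random.default_rng(0).random(time.size),
        dims="time",
        coords={"time": time},
    ).chunk({"time": 365})

    # Default path: behaves exactly like da.resample(time="YS").map(func).
    annual_max = resample_map(da, "time", "YS", lambda grp: grp.max("time"))

    # map_blocks path (requires flox): rechunks so that no resampling group
    # straddles a chunk boundary, then runs the resample inside xr.map_blocks.
    with set_options(resample_map_blocks=True):
        annual_max_blocks = resample_map(da, "time", "YS", lambda grp: grp.max("time"))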

From 9778739ccc60bd13ef2550fcab6722b66c0a38cf Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Fri, 16 Aug 2024 12:16:53 -0400
Subject: [PATCH 08/25] Resample before spells

---
 xclim/indices/generic.py | 127 +++++++++++++++++++++------------------
 xclim/indices/helpers.py |   5 ++
 2 files changed, 75 insertions(+), 57 deletions(-)

diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py
index 2e748e770..858c3bf08 100644
--- a/xclim/indices/generic.py
+++ b/xclim/indices/generic.py
@@ -359,13 +359,13 @@ def get_daily_events(

 def spell_mask(
-    data: xarray.DataArray | Sequence[xarray.DataArray],
+    data: xarray.DataArray,
     window: int,
     win_reducer: str,
     op: str,
-    thresh: float | Sequence[float],
+    thresh: float | xarray.DataArray,
     weights: Sequence[float] = None,
-    var_reducer: str = "all",
+    var_reducer: str | None = None,
 ) -> xarray.DataArray:
     """Compute the boolean mask of data points that are part of a spell as defined by a rolling statistic.

     Parameters
     ----------
-    data: DataArray or sequence of DataArray
-        The input data. Can be a list, in which case the condition is checked on all variables.
-        See var_reducer for the latter case.
+    data: DataArray
+        The input data. Considered multivariate if `var_reducer` is given and data has a "variable" dimension.
+        See `var_reducer` and `threshold` for the latter case.
     window: int
         The length of the rolling window in which to compute statistics.
     win_reducer: {'min', 'max', 'sum', 'mean'}
         The statistics to compute on the rolling window.
     op: {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
         The comparison operator to use when finding spells.
-    thresh: float or sequence of floats
+    thresh: float or DataArray
         The threshold to compare the rolling statistics against, as ``window_stats op threshold``.
-        If data is a list, this must be a list of the same length with a threshold for each variable.
-        This function does not handle units and can't accept Quantified objects.
+        If data is multivariate, this must be a DataArray with a "variable" dimension of the same length, with a threshold for each variable.
+        This function does not handle units.
     weights: sequence of floats
         A list of weights of the same length as the window. Only supported if `win_reducer` is "mean".
-    var_reducer: {'all', 'any'}
-        If the data is a list, the condition must either be fulfilled on *all*
+    var_reducer: {'all', 'any'}, optional
+        If the data is multivariate, the condition must either be fulfilled on *all*
         or *any* variables for the period to be considered a spell.
+        If None (default), the data is not considered multivariate.

     Returns
     -------
     xarray.DataArray
         Same shape as ``data``, but boolean.
-        If ``data`` was a list, this is a DataArray of the same shape as the alignment of all variables.
+        If ``data`` was multivariate, the variable dimension has been removed.
     """
     # Checks
-    if not isinstance(data, xarray.DataArray):
-        # thus a sequence
-        if np.isscalar(thresh) or len(data) != len(thresh):
-            raise ValueError(
-                "When ``data`` is given as a list, ``threshold`` must be a sequence of the same length."
-            )
-        data = xarray.concat(data, "variable")
-        thresh = xarray.DataArray(thresh, dims=("variable",))
+    multivariate = False
+    if var_reducer is not None:
+        if 'variable' in data.dims and 'variable' in thresh.dims:
+            multivariate = True
+        else:
+            raise ValueError("'var_reducer' was given but the data does not have 'variable' dimension.")
     if weights is not None:
         if win_reducer != "mean":
             raise ValueError(
@@ -421,7 +420,7 @@
     if window == 1:
         # Fast path
         mask = compare(data, op, thresh)
-        if not np.isscalar(thresh):
+        if multivariate:
             is_in_spell = getattr(mask, var_reducer)("variable")
     elif (win_reducer == "min" and op in [">", ">=", "ge", "gt"]) or (
         win_reducer == "max" and op in ["<", "<=", "le", "lt"]
@@ -429,7 +428,7 @@
        # Fast path for specific cases, this yields a smaller dask graph (rolling twice is expensive!)
        # For these two cases, a day can't be part of a spell if it doesn't respect the condition itself
         mask = compare(data, op, thresh)
-        if not np.isscalar(thresh):
+        if multivariate:
             mask = getattr(mask, var_reducer)("variable")
         # We need to filter out the spells shorter than "window"
         # find sequences of consecutive respected constraints
@@ -449,7 +448,7 @@
         spell_value = getattr(data_pad.rolling(time=window), win_reducer)()
         # True at the end of a spell respecting the condition
         mask = compare(spell_value, op, thresh)
-        if not np.isscalar(thresh):
+        if multivariate:
             mask = getattr(mask, var_reducer)("variable")
         # True for all days part of a spell that respected the condition (shift because of the two rollings)
         is_in_spell = (mask.rolling(time=window).sum() >= 1).shift(time=-(window - 1))
@@ -467,7 +466,7 @@ def spell_length_statistics(
     op: str,
     spell_reducer: str,
     freq: str,
-    resample_before_rl: bool = True,
+    resample_before_spells: bool = False,
     **indexer,
 ):
     r"""Statistics on spell lengths.
@@ -492,9 +491,9 @@
         Statistic on the spell lengths.
     freq : str
         Resampling frequency.
-    resample_before_rl : bool
-        Determines if the resampling should take place before or after the run
-        length encoding (or a similar algorithm) is applied to runs.
+    resample_before_spells : bool
+        Determines if the resampling should take place before or after finding the
+        spells. If it takes place before, spells cannot cross the period boundary.
     \*\*indexer
         Indexing parameters to compute the indicator on a temporal subset of the data.
         It accepts the same arguments as :py:func:`xclim.indices.generic.select_time`.
@@ -535,18 +534,23 @@
     bivariate_spell_length_statistics : The bivariate version of this function.
""" thresh = convert_units_to(threshold, data, context="infer") - is_in_spell = spell_mask(data, window, win_reducer, op, thresh).astype(np.float32) - is_in_spell = select_time(is_in_spell, **indexer) - out = rl.resample_and_rl( - is_in_spell, - resample_before_rl, - rl.rle_statistics, - reducer=spell_reducer, - # The code above already ensured only spell of the minimum length are selected - window=1, - freq=freq, - ) + def _spell(da, frq = None): + is_in_spell = spell_mask(da, window, win_reducer, op, thresh).astype(np.float32) + is_in_spell = select_time(is_in_spell, **indexer) + + return rl.rle_statistics( + is_in_spell, + reducer=spell_reducer, + # The code above already ensured only spell of the minimum length are selected + window=1, + freq=frq, + ) + + if resample_before_spells: + out = resample_map(data, 'time', freq, _spell) + else: + out = _spell(data, freq) if spell_reducer == "count": return out.assign_attrs(units="") @@ -565,7 +569,7 @@ def bivariate_spell_length_statistics( op: str, spell_reducer: str, freq: str, - resample_before_rl: bool = True, + resample_before_spells: bool = True, **indexer, ): r"""Statistics on spells lengths based on two variables. @@ -594,9 +598,9 @@ def bivariate_spell_length_statistics( Statistic on the spell lengths. freq : str Resampling frequency. - resample_before_rl : bool - Determines if the resampling should take place before or after the run - length encoding (or a similar algorithm) is applied to runs. + resample_before_spells : bool + Determines if the resampling should take place before or after finding the + spells. If it takes place before, spells cannot cross the period boundary. \*\*indexer Indexing parameters to compute the indicator on a temporal subset of the data. It accepts the same arguments as :py:func:`xclim.indices.generic.select_time`. 
@@ -609,25 +613,34 @@
     """
     thresh1 = convert_units_to(threshold1, data1, context="infer")
     thresh2 = convert_units_to(threshold2, data2, context="infer")
-    is_in_spell = spell_mask(
-        [data1, data2], window, win_reducer, op, [thresh1, thresh2], var_reducer="all"
-    ).astype(np.float32)
-    is_in_spell = select_time(is_in_spell, **indexer)
-
-    out = rl.resample_and_rl(
-        is_in_spell,
-        resample_before_rl,
-        rl.rle_statistics,
-        reducer=spell_reducer,
-        # The code above already ensured only spells of the minimum length are selected
-        window=1,
-        freq=freq,
-    )
+
+    data = xr.concat([data2, data2], 'variable')
+    if isinstance(thresh1, xr.DataArray):
+        thresh = xr.concat([thresh1, thresh2], 'variable')
+    else:
+        thresh = xr.DataArray([thresh1, thresh2], dims=('variable',))
+
+    def _spell(da, frq = None):
+        is_in_spell = spell_mask(da, window, win_reducer, op, thresh, var_reducer='all').astype(np.float32)
+        is_in_spell = select_time(is_in_spell, **indexer)
+
+        return rl.rle_statistics(
+            is_in_spell,
+            reducer=spell_reducer,
+            # The code above already ensured only spells of the minimum length are selected
+            window=1,
+            freq=frq,
+        )
+
+    if resample_before_spells:
+        out = resample_map(data, 'time', freq, _spell, reduced_dims=['variable'])
+    else:
+        out = _spell(data, freq)

     if spell_reducer == "count":
         return out.assign_attrs(units="")
     # All other cases are statistics of the number of timesteps
-    return to_agg_units(out, data1, "count")
+    return to_agg_units(out, data, "count")


 @declare_relative_units(thresh="")

diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py
index 66019ac07..99b6cb8cc 100644
--- a/xclim/indices/helpers.py
+++ b/xclim/indices/helpers.py
@@ -552,6 +552,7 @@ def resample_map(
     map_blocks: bool | Literal["from_context"] = "from_context",
     resample_kwargs: dict | None = None,
     map_kwargs: dict | None = None,
+    reduced_dims: Sequence[str] | None = None
 ) -> xr.DataArray | xr.Dataset:
     r"""
     Wraps xarray's resample(...).map() with a :py:func:`xarray.map_blocks`, ensuring the chunking is appropriate using flox.

     Parameters
     ----------
     obj: DataArray or Dataset
         The xarray object to resample.
     dim: str
         Dimension over which to resample.
     freq : str
         Resampling frequency along `dim`.
     func : callable
         Function to map on each resampled group.
     map_blocks: bool or "from_context"
         If True, the resample().map() call is wrapped inside a `map_blocks`.
         If False, this does not do anything special.
         If "from_context", xclim's "resample_map_blocks" option is used.
         If the object is not using dask, this is set to False.
     resample_kwargs: dict, optional
         Other arguments to pass to `obj.resample()`.
     map_kwargs: dict, optional
         Arguments to pass to `map`.
+    reduced_dims: sequence of strings, optional
+        A list of dims on `obj` that will be reduced (removed) by the mapped function.

     Returns
     -------
     Resampled object.
@@ -605,6 +608,8 @@ def _resample_map(obj_chnk, dm, frq, rs_kws, fun, mp_kws):
     # Template. We are hoping that this takes a negligible time as it is never loaded.
     template = obj_rechunked.resample(**{dim: freq}, **resample_kwargs).first()
+    if reduced_dims:  # Removed reduced dims
+        template = template.isel({d: 0 for d in reduced_dims}, drop=True)

     # New chunks along time : infer the number of elements resulting from the resampling of each chunk
     if isinstance(obj_rechunked, xr.Dataset):

From a7e5bde82a37f56ba06aee1f402bb3209299f33e Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Fri, 16 Aug 2024 16:18:01 +0000
Subject: [PATCH 09/25] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---
 xclim/indices/generic.py | 28 ++++++++++++++----------
 xclim/indices/helpers.py |  4 ++--
 2 files changed, 18 insertions(+), 14 deletions(-)

diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py
index 858c3bf08..f89af8662 100644
--- a/xclim/indices/generic.py
+++ b/xclim/indices/generic.py
@@ -403,10 +403,12 @@ def spell_mask(
     # Checks
     multivariate = False
     if var_reducer is not None:
-        if 'variable' in data.dims and 'variable' in thresh.dims:
+        if "variable" in data.dims and "variable" in thresh.dims:
             multivariate = True
         else:
-            raise ValueError("'var_reducer' was given but the data does not have 'variable' dimension.")
+            raise ValueError(
+                "'var_reducer' was given but the data does not have 'variable' dimension."
+            )
     if weights is not None:
         if win_reducer != "mean":
             raise ValueError(
@@ -492,7 +494,7 @@
     freq : str
         Resampling frequency.
     resample_before_spells : bool
-        Determines if the resampling should take place before or after finding the
+        Determines if the resampling should take place before or after finding the
         spells. If it takes place before, spells cannot cross the period boundary.
     \*\*indexer
         Indexing parameters to compute the indicator on a temporal subset of the data.
@@ -614,14 +616,16 @@ def bivariate_spell_length_statistics( thresh1 = convert_units_to(threshold1, data1, context="infer") thresh2 = convert_units_to(threshold2, data2, context="infer") - data = xr.concat([data2, data2], 'variable') + data = xr.concat([data2, data2], "variable") if isinstance(thresh1, xr.DataArray): - thresh = xr.concat([thresh1, thresh2], 'variable') + thresh = xr.concat([thresh1, thresh2], "variable") else: - thresh = xr.DataArray([thresh1, thresh2], dims=('variable',)) + thresh = xr.DataArray([thresh1, thresh2], dims=("variable",)) - def _spell(da, frq = None): - is_in_spell = spell_mask(da, window, win_reducer, op, thresh, var_reducer='all').astype(np.float32) + def _spell(da, frq=None): + is_in_spell = spell_mask( + da, window, win_reducer, op, thresh, var_reducer="all" + ).astype(np.float32) is_in_spell = select_time(is_in_spell, **indexer) return rl.rle_statistics( @@ -633,7 +637,7 @@ def _spell(da, frq = None): ) if resample_before_spells: - out = resample_map(data, 'time', freq, _spell, reduced_dims=['variable']) + out = resample_map(data, "time", freq, _spell, reduced_dims=["variable"]) else: out = _spell(data, freq) diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py index 99b6cb8cc..8b0d64de5 100644 --- a/xclim/indices/helpers.py +++ b/xclim/indices/helpers.py @@ -552,7 +552,7 @@ def resample_map( map_blocks: bool | Literal["from_context"] = "from_context", resample_kwargs: dict | None = None, map_kwargs: dict | None = None, - reduced_dims: Sequence[str] | None = None + reduced_dims: Sequence[str] | None = None, ) -> xr.DataArray | xr.Dataset: r""" Wraps xarray's resample(...).map() with a :py:func:`xarray.map_blocks`, ensuring the chunking is appropriate using flox. @@ -608,7 +608,7 @@ def _resample_map(obj_chnk, dm, frq, rs_kws, fun, mp_kws): # Template. We are hoping that this takes a negligeable time as it is never loaded. template = obj_rechunked.resample(**{dim: freq}, **resample_kwargs).first() - if reduced_dims: # Removed reduced dims + if reduced_dims: # Removed reduced dims template = template.isel({d: 0 for d in reduced_dims}, drop=True) # New chunks along time : infer the number of elements resulting from the resampling of each chunk From b9d79b0d9984f951e1cf6f274a6607ef2847a20b Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Fri, 16 Aug 2024 12:37:29 -0400 Subject: [PATCH 10/25] Revert "[pre-commit.ci] auto fixes from pre-commit.com hooks" This reverts commit a7e5bde82a37f56ba06aee1f402bb3209299f33e. --- xclim/indices/generic.py | 28 ++++++++++++---------------- xclim/indices/helpers.py | 4 ++-- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py index f89af8662..858c3bf08 100644 --- a/xclim/indices/generic.py +++ b/xclim/indices/generic.py @@ -403,12 +403,10 @@ def spell_mask( # Checks multivariate = False if var_reducer is not None: - if "variable" in data.dims and "variable" in thresh.dims: + if 'variable' in data.dims and 'variable' in thresh.dims: multivariate = True else: - raise ValueError( - "'var_reducer' was given but the data does not have 'variable' dimension." - ) + raise ValueError("'var_reducer' was given but the data does not have 'variable' dimension.") if weights is not None: if win_reducer != "mean": raise ValueError( @@ -494,7 +492,7 @@ def spell_length_statistics( freq : str Resampling frequency. 
     resample_before_spells : bool
-        Determines if the resampling should take place before or after finding the
+        Determines if the resampling should take place before or after finding the
         spells. If it takes place before, spells cannot cross the period boundary.
     \*\*indexer
         Indexing parameters to compute the indicator on a temporal subset of the data.
@@ -614,14 +616,16 @@
     thresh1 = convert_units_to(threshold1, data1, context="infer")
     thresh2 = convert_units_to(threshold2, data2, context="infer")

-    data = xr.concat([data2, data2], 'variable')
+    data = xr.concat([data2, data2], "variable")
     if isinstance(thresh1, xr.DataArray):
-        thresh = xr.concat([thresh1, thresh2], 'variable')
+        thresh = xr.concat([thresh1, thresh2], "variable")
     else:
-        thresh = xr.DataArray([thresh1, thresh2], dims=('variable',))
+        thresh = xr.DataArray([thresh1, thresh2], dims=("variable",))

-    def _spell(da, frq = None):
-        is_in_spell = spell_mask(da, window, win_reducer, op, thresh, var_reducer='all').astype(np.float32)
+    def _spell(da, frq=None):
+        is_in_spell = spell_mask(
+            da, window, win_reducer, op, thresh, var_reducer="all"
+        ).astype(np.float32)
         is_in_spell = select_time(is_in_spell, **indexer)

         return rl.rle_statistics(
@@ -633,7 +637,7 @@ def _spell(da, frq=None):
         )

     if resample_before_spells:
-        out = resample_map(data, 'time', freq, _spell, reduced_dims=['variable'])
+        out = resample_map(data, "time", freq, _spell, reduced_dims=["variable"])
     else:
         out = _spell(data, freq)

diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py
index 99b6cb8cc..8b0d64de5 100644
--- a/xclim/indices/helpers.py
+++ b/xclim/indices/helpers.py
@@ -552,7 +552,7 @@ def resample_map(
     map_blocks: bool | Literal["from_context"] = "from_context",
     resample_kwargs: dict | None = None,
     map_kwargs: dict | None = None,
-    reduced_dims: Sequence[str] | None = None
+    reduced_dims: Sequence[str] | None = None,
 ) -> xr.DataArray | xr.Dataset:
     r"""
     Wraps xarray's resample(...).map() with a :py:func:`xarray.map_blocks`, ensuring the chunking is appropriate using flox.
@@ -608,7 +608,7 @@ def _resample_map(obj_chnk, dm, frq, rs_kws, fun, mp_kws):
     # Template. We are hoping that this takes a negligible time as it is never loaded.
     template = obj_rechunked.resample(**{dim: freq}, **resample_kwargs).first()
-    if reduced_dims: # Removed reduced dims
+    if reduced_dims:  # Removed reduced dims
         template = template.isel({d: 0 for d in reduced_dims}, drop=True)

From b9d79b0d9984f951e1cf6f274a6607ef2847a20b Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Fri, 16 Aug 2024 12:37:29 -0400
Subject: [PATCH 10/25] Revert "[pre-commit.ci] auto fixes from pre-commit.com
 hooks"

This reverts commit a7e5bde82a37f56ba06aee1f402bb3209299f33e.

---
 xclim/indices/generic.py | 28 ++++++++++++++----------
 xclim/indices/helpers.py |  4 ++--
 2 files changed, 14 insertions(+), 18 deletions(-)

diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py
index f89af8662..858c3bf08 100644
--- a/xclim/indices/generic.py
+++ b/xclim/indices/generic.py
@@ -403,12 +403,10 @@ def spell_mask(
     # Checks
     multivariate = False
     if var_reducer is not None:
-        if "variable" in data.dims and "variable" in thresh.dims:
+        if 'variable' in data.dims and 'variable' in thresh.dims:
             multivariate = True
         else:
-            raise ValueError(
-                "'var_reducer' was given but the data does not have 'variable' dimension."
-            )
+            raise ValueError("'var_reducer' was given but the data does not have 'variable' dimension.")
     if weights is not None:
         if win_reducer != "mean":
             raise ValueError(
@@ -494,7 +492,7 @@
     freq : str
         Resampling frequency.
     resample_before_spells : bool
-        Determines if the resampling should take place before or after finding the
+        Determines if the resampling should take place before or after finding the
         spells. If it takes place before, spells cannot cross the period boundary.
@@ -537,7 +535,7 @@
     """
     thresh = convert_units_to(threshold, data, context="infer")

-    def _spell(da, frq=None):
+    def _spell(da, frq = None):
         is_in_spell = spell_mask(da, window, win_reducer, op, thresh).astype(np.float32)
         is_in_spell = select_time(is_in_spell, **indexer)
@@ -550,7 +548,7 @@
     if resample_before_spells:
-        out = resample_map(data, "time", freq, _spell)
+        out = resample_map(data, 'time', freq, _spell)
     else:
         out = _spell(data, freq)
@@ -616,14 +614,16 @@
-    data = xr.concat([data2, data2], "variable")
+    data = xr.concat([data2, data2], 'variable')
     if isinstance(thresh1, xr.DataArray):
-        thresh = xr.concat([thresh1, thresh2], "variable")
+        thresh = xr.concat([thresh1, thresh2], 'variable')
     else:
-        thresh = xr.DataArray([thresh1, thresh2], dims=("variable",))
+        thresh = xr.DataArray([thresh1, thresh2], dims=('variable',))

-    def _spell(da, frq=None):
-        is_in_spell = spell_mask(
-            da, window, win_reducer, op, thresh, var_reducer="all"
-        ).astype(np.float32)
+    def _spell(da, frq = None):
+        is_in_spell = spell_mask(da, window, win_reducer, op, thresh, var_reducer='all').astype(np.float32)
         is_in_spell = select_time(is_in_spell, **indexer)
@@ -637,7 +633,7 @@
     if resample_before_spells:
-        out = resample_map(data, "time", freq, _spell, reduced_dims=["variable"])
+        out = resample_map(data, 'time', freq, _spell, reduced_dims=['variable'])
     else:
         out = _spell(data, freq)

diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py
index 8b0d64de5..99b6cb8cc 100644
--- a/xclim/indices/helpers.py
+++ b/xclim/indices/helpers.py
@@ -552,7 +552,7 @@ def resample_map(
     map_blocks: bool | Literal["from_context"] = "from_context",
     resample_kwargs: dict | None = None,
     map_kwargs: dict | None = None,
-    reduced_dims: Sequence[str] | None = None,
+    reduced_dims: Sequence[str] | None = None
 ) -> xr.DataArray | xr.Dataset:
     r"""
     Wraps xarray's resample(...).map() with a :py:func:`xarray.map_blocks`, ensuring the chunking is appropriate using flox.
@@ -608,7 +608,7 @@ def _resample_map(obj_chnk, dm, frq, rs_kws, fun, mp_kws):
     # Template. We are hoping that this takes a negligible time as it is never loaded.
     template = obj_rechunked.resample(**{dim: freq}, **resample_kwargs).first()
-    if reduced_dims:  # Removed reduced dims
+    if reduced_dims: # Removed reduced dims
         template = template.isel({d: 0 for d in reduced_dims}, drop=True)

     # New chunks along time : infer the number of elements resulting from the resampling of each chunk

From b0fd634cbcf246c22b33b26d25f19d027442da2a Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Fri, 16 Aug 2024 12:37:46 -0400
Subject: [PATCH 11/25] Revert "Resample before spells"

This reverts commit 9778739ccc60bd13ef2550fcab6722b66c0a38cf.

---
 xclim/indices/generic.py | 127 ++++++++++++++++++---------------------
 xclim/indices/helpers.py |   5 --
 2 files changed, 57 insertions(+), 75 deletions(-)

diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py
index 858c3bf08..2e748e770 100644
--- a/xclim/indices/generic.py
+++ b/xclim/indices/generic.py
@@ -359,13 +359,13 @@ def get_daily_events(

 def spell_mask(
-    data: xarray.DataArray,
+    data: xarray.DataArray | Sequence[xarray.DataArray],
     window: int,
     win_reducer: str,
     op: str,
-    thresh: float | xarray.DataArray,
+    thresh: float | Sequence[float],
     weights: Sequence[float] = None,
-    var_reducer: str | None = None,
+    var_reducer: str = "all",
 ) -> xarray.DataArray:
     """Compute the boolean mask of data points that are part of a spell as defined by a rolling statistic.

     Parameters
     ----------
-    data: DataArray
-        The input data. Considered multivariate if `var_reducer` is given and data has a "variable" dimension.
-        See `var_reducer` and `threshold` for the latter case.
+    data: DataArray or sequence of DataArray
+        The input data. Can be a list, in which case the condition is checked on all variables.
+        See var_reducer for the latter case.
     window: int
         The length of the rolling window in which to compute statistics.
     win_reducer: {'min', 'max', 'sum', 'mean'}
         The statistics to compute on the rolling window.
     op: {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"}
         The comparison operator to use when finding spells.
-    thresh: float or DataArray
+    thresh: float or sequence of floats
         The threshold to compare the rolling statistics against, as ``window_stats op threshold``.
-        If data is multivariate, this must be a DataArray with a "variable" dimension of the same length, with a threshold for each variable.
-        This function does not handle units.
+        If data is a list, this must be a list of the same length with a threshold for each variable.
+        This function does not handle units and can't accept Quantified objects.
     weights: sequence of floats
         A list of weights of the same length as the window. Only supported if `win_reducer` is "mean".
-    var_reducer: {'all', 'any'}, optional
-        If the data is multivariate, the condition must either be fulfilled on *all*
+    var_reducer: {'all', 'any'}
+        If the data is a list, the condition must either be fulfilled on *all*
         or *any* variables for the period to be considered a spell.
-        If None (default), the data is not considered multivariate.

     Returns
     -------
     xarray.DataArray
         Same shape as ``data``, but boolean.
-        If ``data`` was multivariate, the variable dimension has been removed.
+        If ``data`` was a list, this is a DataArray of the same shape as the alignment of all variables.
""" # Checks - multivariate = False - if var_reducer is not None: - if 'variable' in data.dims and 'variable' in thresh.dims: - multivariate = True - else: - raise ValueError("'var_reducer' was given but the data does not have 'variable' dimension.") + if not isinstance(data, xarray.DataArray): + # thus a sequence + if np.isscalar(thresh) or len(data) != len(thresh): + raise ValueError( + "When ``data`` is given as a list, ``threshold`` must be a sequence of the same length." + ) + data = xarray.concat(data, "variable") + thresh = xarray.DataArray(thresh, dims=("variable",)) if weights is not None: if win_reducer != "mean": raise ValueError( @@ -420,7 +421,7 @@ def spell_mask( if window == 1: # Fast path mask = compare(data, op, thresh) - if multivariate: + if not np.isscalar(thresh): is_in_spell = getattr(mask, var_reducer)("variable") elif (win_reducer == "min" and op in [">", ">=", "ge", "gt"]) or ( win_reducer == "max" and op in ["`<", "<=", "le", "lt"] @@ -428,7 +429,7 @@ def spell_mask( # Fast path for specific cases, this yields a smaller dask graph (rolling twice is expensive!) # For these two cases, a day can't be part of a spell if it doesn't respect the condition itself mask = compare(data, op, thresh) - if multivariate: + if not np.isscalar(thresh): mask = getattr(mask, var_reducer)("variable") # We need to filter out the spells shorter than "window" # find sequences of consecutive respected constraints @@ -448,7 +449,7 @@ def spell_mask( spell_value = getattr(data_pad.rolling(time=window), win_reducer)() # True at the end of a spell respecting the condition mask = compare(spell_value, op, thresh) - if multivariate: + if not np.isscalar(thresh): mask = getattr(mask, var_reducer)("variable") # True for all days part of a spell that respected the condition (shift because of the two rollings) is_in_spell = (mask.rolling(time=window).sum() >= 1).shift(time=-(window - 1)) @@ -466,7 +467,7 @@ def spell_length_statistics( op: str, spell_reducer: str, freq: str, - resample_before_spells: bool = False, + resample_before_rl: bool = True, **indexer, ): r"""Statistics on spells lengths. @@ -491,9 +492,9 @@ def spell_length_statistics( Statistic on the spell lengths. freq : str Resampling frequency. - resample_before_spells : bool - Determines if the resampling should take place before or after finding the - spells. If it takes place before, spells cannot cross the period boundary. + resample_before_rl : bool + Determines if the resampling should take place before or after the run + length encoding (or a similar algorithm) is applied to runs. \*\*indexer Indexing parameters to compute the indicator on a temporal subset of the data. It accepts the same arguments as :py:func:`xclim.indices.generic.select_time`. @@ -534,23 +535,18 @@ def spell_length_statistics( bivariate_spell_length_statistics : The bivariate version of this function. 
""" thresh = convert_units_to(threshold, data, context="infer") + is_in_spell = spell_mask(data, window, win_reducer, op, thresh).astype(np.float32) + is_in_spell = select_time(is_in_spell, **indexer) - def _spell(da, frq = None): - is_in_spell = spell_mask(da, window, win_reducer, op, thresh).astype(np.float32) - is_in_spell = select_time(is_in_spell, **indexer) - - return rl.rle_statistics( - is_in_spell, - reducer=spell_reducer, - # The code above already ensured only spell of the minimum length are selected - window=1, - freq=frq, - ) - - if resample_before_spells: - out = resample_map(data, 'time', freq, _spell) - else: - out = _spell(data, freq) + out = rl.resample_and_rl( + is_in_spell, + resample_before_rl, + rl.rle_statistics, + reducer=spell_reducer, + # The code above already ensured only spell of the minimum length are selected + window=1, + freq=freq, + ) if spell_reducer == "count": return out.assign_attrs(units="") @@ -569,7 +565,7 @@ def bivariate_spell_length_statistics( op: str, spell_reducer: str, freq: str, - resample_before_spells: bool = True, + resample_before_rl: bool = True, **indexer, ): r"""Statistics on spells lengths based on two variables. @@ -598,9 +594,9 @@ def bivariate_spell_length_statistics( Statistic on the spell lengths. freq : str Resampling frequency. - resample_before_spells : bool - Determines if the resampling should take place before or after finding the - spells. If it takes place before, spells cannot cross the period boundary. + resample_before_rl : bool + Determines if the resampling should take place before or after the run + length encoding (or a similar algorithm) is applied to runs. \*\*indexer Indexing parameters to compute the indicator on a temporal subset of the data. It accepts the same arguments as :py:func:`xclim.indices.generic.select_time`. 
@@ -613,34 +609,25 @@
     """
     thresh1 = convert_units_to(threshold1, data1, context="infer")
     thresh2 = convert_units_to(threshold2, data2, context="infer")
-
-    data = xr.concat([data2, data2], 'variable')
-    if isinstance(thresh1, xr.DataArray):
-        thresh = xr.concat([thresh1, thresh2], 'variable')
-    else:
-        thresh = xr.DataArray([thresh1, thresh2], dims=('variable',))
-
-    def _spell(da, frq = None):
-        is_in_spell = spell_mask(da, window, win_reducer, op, thresh, var_reducer='all').astype(np.float32)
-        is_in_spell = select_time(is_in_spell, **indexer)
-
-        return rl.rle_statistics(
-            is_in_spell,
-            reducer=spell_reducer,
-            # The code above already ensured only spells of the minimum length are selected
-            window=1,
-            freq=frq,
-        )
-
-    if resample_before_spells:
-        out = resample_map(data, 'time', freq, _spell, reduced_dims=['variable'])
-    else:
-        out = _spell(data, freq)
+    is_in_spell = spell_mask(
+        [data1, data2], window, win_reducer, op, [thresh1, thresh2], var_reducer="all"
+    ).astype(np.float32)
+    is_in_spell = select_time(is_in_spell, **indexer)
+
+    out = rl.resample_and_rl(
+        is_in_spell,
+        resample_before_rl,
+        rl.rle_statistics,
+        reducer=spell_reducer,
+        # The code above already ensured only spells of the minimum length are selected
+        window=1,
+        freq=freq,
+    )

     if spell_reducer == "count":
         return out.assign_attrs(units="")
     # All other cases are statistics of the number of timesteps
-    return to_agg_units(out, data, "count")
+    return to_agg_units(out, data1, "count")


 @declare_relative_units(thresh="")

diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py
index 99b6cb8cc..66019ac07 100644
--- a/xclim/indices/helpers.py
+++ b/xclim/indices/helpers.py
@@ -552,7 +552,6 @@ def resample_map(
     map_blocks: bool | Literal["from_context"] = "from_context",
     resample_kwargs: dict | None = None,
     map_kwargs: dict | None = None,
-    reduced_dims: Sequence[str] | None = None
 ) -> xr.DataArray | xr.Dataset:
     r"""
     Wraps xarray's resample(...).map() with a :py:func:`xarray.map_blocks`, ensuring the chunking is appropriate using flox.
@@ -576,8 +575,6 @@
         Other arguments to pass to `obj.resample()`.
     map_kwargs: dict, optional
         Arguments to pass to `map`.
-    reduced_dims: sequence of strings, optional
-        A list of dims on `obj` that will be reduced (removed) by the mapped function.

     Returns
     -------
     Resampled object.
@@ -605,8 +608,6 @@ def _resample_map(obj_chnk, dm, frq, rs_kws, fun, mp_kws):
     # Template. We are hoping that this takes a negligible time as it is never loaded.
     template = obj_rechunked.resample(**{dim: freq}, **resample_kwargs).first()
-    if reduced_dims:  # Removed reduced dims
-        template = template.isel({d: 0 for d in reduced_dims}, drop=True)

     # New chunks along time : infer the number of elements resulting from the resampling of each chunk
     if isinstance(obj_rechunked, xr.Dataset):

From cef0e252b20021dbb16dfc9cf66d00e4e25741ee Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Thu, 5 Sep 2024 15:41:13 -0400
Subject: [PATCH 12/25] multi reducing

---
 tests/test_generic.py    |  31 ++++++++++++
 xclim/indices/generic.py | 107 +++++++++++++++++++++++++--------------
 2 files changed, 101 insertions(+), 37 deletions(-)

diff --git a/tests/test_generic.py b/tests/test_generic.py
index c716a7d86..0054dd935 100644
--- a/tests/test_generic.py
+++ b/tests/test_generic.py
@@ -737,3 +737,34 @@ def test_errors(self):
         # Weights must have same length as window
         with pytest.raises(ValueError, match="Weights have a different length"):
             generic.spell_mask(data, 3, "mean", "<=", 2, weights=[1, 2])
+
+
+def test_spell_length_statistics_multi(tasmin_series, tasmax_series):
+    tn = tasmin_series(
+        np.zeros(
+            365,
+        )
+        + 270,
+        start="2001-01-01",
+    )
+    tx = tasmax_series(
+        np.zeros(
+            365,
+        )
+        + 270,
+        start="2001-01-01",
+    )
+
+    outc, outs, outm = generic.bivariate_spell_length_statistics(
+        tn,
+        "0 °C",
+        tx,
+        "1°C",
+        window=5,
+        win_reducer="min",
+        op="<",
+        spell_reducer=["count", "sum", "max"],
+        freq="YS",
+    )
+    xr.testing.assert_equal(outs, outm)
+    np.testing.assert_allclose(outc, 1)

diff --git a/xclim/indices/generic.py b/xclim/indices/generic.py
index 13132c1c6..027b48a2f 100644
--- a/xclim/indices/generic.py
+++ b/xclim/indices/generic.py
@@ -407,7 +407,10 @@ def spell_mask(
                 "When ``data`` is given as a list, ``threshold`` must be a sequence of the same length."
             )
         data = xarray.concat(data, "variable")
-        thresh = xarray.DataArray(thresh, dims=("variable",))
+        if isinstance(thresh[0], xarray.DataArray):
+            thresh = xr.concat(thresh, "variable")
+        else:
+            thresh = xarray.DataArray(thresh, dims=("variable",))
     if weights is not None:
         if win_reducer != "mean":
             raise ValueError(
@@ -458,6 +461,50 @@
     return is_in_spell


+def _spell_length_statistics(
+    data: xarray.DataArray | Sequence[xarray.DataArray],
+    thresh: float | xarray.DataArray | Sequence[xarray.DataArray] | Sequence[float],
+    window: int,
+    win_reducer: str,
+    op: str,
+    spell_reducer: str | Sequence[str],
+    freq: str,
+    resample_before_rl: bool = True,
+    **indexer,
+) -> xarray.DataArray | Sequence[xarray.DataArray]:
+    if isinstance(spell_reducer, str):
+        spell_reducer = [spell_reducer]
+    is_in_spell = spell_mask(data, window, win_reducer, op, thresh).astype(np.float32)
+    is_in_spell = select_time(is_in_spell, **indexer)
+
+    outs = []
+    for sr in spell_reducer:
+        out = rl.resample_and_rl(
+            is_in_spell,
+            resample_before_rl,
+            rl.rle_statistics,
+            reducer=sr,
+            # The code above already ensured only spells of the minimum length are selected
+            window=1,
+            freq=freq,
+        )
+
+        if sr == "count":
+            outs.append(out.assign_attrs(units=""))
+        else:
+            # All other cases are statistics of the number of timesteps
+            outs.append(
+                to_agg_units(
+                    out,
+                    data if isinstance(data, xarray.DataArray) else data[0],
+                    "count",
+                )
+            )
+    if len(outs) == 1:
+        return outs[0]
+    return tuple(outs)
+
+
 @declare_relative_units(threshold="")
 def spell_length_statistics(
     data: xarray.DataArray,
@@ -488,8 +535,8 @@
        Note that this does not matter when `window` is 1.
op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} Logical operator. Ex: spell_value > thresh. - spell_reducer : {'max', 'sum', 'count'} - Statistic on the spell lengths. + spell_reducer : {'max', 'sum', 'count'} or sequence thereof + Statistic on the spell lengths. If a list, multiple statistics are computed. freq : str Resampling frequency. resample_before_rl : bool @@ -535,24 +582,18 @@ def spell_length_statistics( bivariate_spell_length_statistics : The bivariate version of this function. """ thresh = convert_units_to(threshold, data, context="infer") - is_in_spell = spell_mask(data, window, win_reducer, op, thresh).astype(np.float32) - is_in_spell = select_time(is_in_spell, **indexer) - - out = rl.resample_and_rl( - is_in_spell, + return _spell_length_statistics( + data, + thresh, + window, + win_reducer, + op, + spell_reducer, + freq, resample_before_rl, - rl.rle_statistics, - reducer=spell_reducer, - # The code above already ensured only spell of the minimum length are selected - window=1, - freq=freq, + **indexer, ) - if spell_reducer == "count": - return out.assign_attrs(units="") - # All other cases are statistics of the number of timesteps - return to_agg_units(out, data, "count") - @declare_relative_units(threshold1="", threshold2="") def bivariate_spell_length_statistics( @@ -590,8 +631,8 @@ def bivariate_spell_length_statistics( Note that this does not matter when `window` is 1. op : {">", "gt", "<", "lt", ">=", "ge", "<=", "le", "==", "eq", "!=", "ne"} Logical operator. Ex: spell_value > thresh. - spell_reducer : {'max', 'sum', 'count'} - Statistic on the spell lengths. + spell_reducer : {'max', 'sum', 'count'} or sequence thereof + Statistic on the spell lengths. If a list, multiple statistics are computed. freq : str Resampling frequency. 
resample_before_rl : bool @@ -609,26 +650,18 @@ def bivariate_spell_length_statistics( """ thresh1 = convert_units_to(threshold1, data1, context="infer") thresh2 = convert_units_to(threshold2, data2, context="infer") - is_in_spell = spell_mask( - [data1, data2], window, win_reducer, op, [thresh1, thresh2], var_reducer="all" - ).astype(np.float32) - is_in_spell = select_time(is_in_spell, **indexer) - - out = rl.resample_and_rl( - is_in_spell, + return _spell_length_statistics( + [data1, data2], + [thresh1, thresh2], + window, + win_reducer, + op, + spell_reducer, + freq, resample_before_rl, - rl.rle_statistics, - reducer=spell_reducer, - # The code above already ensured only spell of the minimum length are selected - window=1, - freq=freq, + **indexer, ) - if spell_reducer == "count": - return out.assign_attrs(units="") - # All other cases are statistics of the number of timesteps - return to_agg_units(out, data1, "count") - @declare_relative_units(thresh="") def season( From 90c14efc4d777d3b05cf2b6436b3dc4a326e0cbb Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Fri, 6 Sep 2024 14:04:30 -0400 Subject: [PATCH 13/25] fix deps - add minimal tests --- pyproject.toml | 2 +- tests/test_indices.py | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6d4fb480b..6dfb96c28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["flit_core >=3.8,<4"] +requires = ["flit_core >=3.9,<4"] build-backend = "flit_core.buildapi" [project] diff --git a/tests/test_indices.py b/tests/test_indices.py index d1da9c54c..80698a681 100644 --- a/tests/test_indices.py +++ b/tests/test_indices.py @@ -1428,11 +1428,12 @@ def test_1d(self, tasmax_series, thresh, window, op, expected): def test_resampling_order(self, tasmax_series, resample_before_rl, expected): a = np.zeros(365) a[5:35] = 31 - tx = tasmax_series(a + K2C) + tx = tasmax_series(a + K2C).chunk() - hsf = xci.hot_spell_frequency( - tx, resample_before_rl=resample_before_rl, freq="MS" - ) + with set_options(resample_map_blocks=resample_before_rl): + hsf = xci.hot_spell_frequency( + tx, resample_before_rl=resample_before_rl, freq="MS" + ).load() assert hsf[1] == expected @@ -1708,10 +1709,11 @@ def test_run_start_at_0(self, pr_series): def test_resampling_order(self, pr_series, resample_before_rl, expected): a = np.zeros(365) + 10 a[5:35] = 0 - pr = pr_series(a) - out = xci.maximum_consecutive_dry_days( - pr, freq="ME", resample_before_rl=resample_before_rl - ) + pr = pr_series(a).chunk() + with set_options(resample_map_blocks=resample_before_rl): + out = xci.maximum_consecutive_dry_days( + pr, freq="ME", resample_before_rl=resample_before_rl + ).load() assert out[0] == expected From 03d236e29d9e9febfc66315cd3b753e698dc28e9 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Fri, 6 Sep 2024 14:08:33 -0400 Subject: [PATCH 14/25] add changelog --- CHANGELOG.rst | 1 + xclim/core/options.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index b28b0e176..78afcc572 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -14,6 +14,7 @@ New features and enhancements ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ * New generic ``xclim.indices.generic.spell_mask`` that returns a mask of which days are part of a spell. Supports multivariate conditions and weights. Used in new generic index ``xclim.indices.generic.bivariate_spell_length_statistics`` that extends ``spell_length_statistics`` to two variables. (:pull:`1885`). 
* Indicator parameters can now be assigned a new name, different from the argument name in the compute function. (:pull:`1885`).
+* New global option ``resample_map_blocks`` to wrap all ``resample().map()`` code inside a call to ``xr.map_blocks``, lowering the number of dask tasks. Uses the utility ``xclim.indices.helpers.resample_map`` and requires ``flox`` to ensure the chunking allows such block-mapping. Defaults to False. (:pull:`1848`).

 Bug fixes
 ^^^^^^^^^
diff --git a/xclim/core/options.py b/xclim/core/options.py
index d75e04c95..3437396d4 100644
--- a/xclim/core/options.py
+++ b/xclim/core/options.py
@@ -189,8 +189,8 @@ class set_options:
     as_dataset : bool
         If True, indicators output datasets. If False, they output DataArrays. Default :``False``.
     resample_map_blocks: bool
-        If True, some indicators will wrap their resampling operations with `xr.map_blocks`, using :py:func:`xclim.indices.utils.resample_map`.
-        This requires `flox` to be installed in order to ensure the chunking is appropriate.
+        If True, some indicators will wrap their resampling operations with `xr.map_blocks`, using :py:func:`xclim.indices.helpers.resample_map`.
+        This requires `flox` to be installed in order to ensure the chunking is appropriate.

     Examples
     --------

From 1f3e82e5d7c3db10c9bcb632630b36ae98ecf579 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Fri, 6 Sep 2024 14:21:45 -0400
Subject: [PATCH 15/25] Don't test resample-map without flox

---
 tests/test_indices.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/tests/test_indices.py b/tests/test_indices.py
index 80698a681..f33ce3ef2 100644
--- a/tests/test_indices.py
+++ b/tests/test_indices.py
@@ -29,6 +29,11 @@
 from xclim.core.options import set_options
 from xclim.core.units import convert_units_to, units

+try:
+    import flox
+except ImportError:
+    flox = None
+
 K2C = 273.15


@@ -1430,7 +1435,9 @@ def test_resampling_order(self, tasmax_series, resample_before_rl, expected):
         a[5:35] = 31
         tx = tasmax_series(a + K2C).chunk()

-        with set_options(resample_map_blocks=resample_before_rl):
+        with set_options(
+            resample_map_blocks=(resample_before_rl and (flox is not None))
+        ):
             hsf = xci.hot_spell_frequency(
                 tx, resample_before_rl=resample_before_rl, freq="MS"
             ).load()
@@ -1710,7 +1717,9 @@ def test_resampling_order(self, pr_series, resample_before_rl, expected):
         a = np.zeros(365) + 10
         a[5:35] = 0
         pr = pr_series(a).chunk()
-        with set_options(resample_map_blocks=resample_before_rl):
+        with set_options(
+            resample_map_blocks=(resample_before_rl and (flox is not None))
+        ):
             out = xci.maximum_consecutive_dry_days(
                 pr, freq="ME", resample_before_rl=resample_before_rl
             ).load()

From b1dd2acdfe39a8097c923e8bec5efe3a124bbe97 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Fri, 6 Sep 2024 14:35:13 -0400
Subject: [PATCH 16/25] Apply suggestions from code review

Co-authored-by: Trevor James Smith <10819524+Zeitsperre@users.noreply.github.com>
---
 xclim/indices/_threshold.py |  2 +-
 xclim/indices/helpers.py    | 17 +++++++++--------
 2 files changed, 10 insertions(+), 9 deletions(-)

diff --git a/xclim/indices/_threshold.py b/xclim/indices/_threshold.py
index 63b4a71cd..3c5326c20 100644
--- a/xclim/indices/_threshold.py
+++ b/xclim/indices/_threshold.py
@@ -30,7 +30,7 @@
     threshold_count,
 )

-from .helpers import resample_map
+from xclim.indices.helpers import resample_map

 # Frequencies : YS: year start, QS-DEC: seasons starting in december, MS: month start
 # See http://pandas.pydata.org/pandas-docs/stable/timeseries.html#offset-aliases
diff --git 
a/xclim/indices/helpers.py b/xclim/indices/helpers.py index 428950137..236452adb 100644 --- a/xclim/indices/helpers.py +++ b/xclim/indices/helpers.py @@ -2,8 +2,8 @@ Indices Helper Functions Submodule ================================== -Functions that encapsulate logic but could be shared by many indices, -but are not index-like themselves (these should go in the :py:mod:`xclim.indices.generic` module). +Functions that encapsulate logic and can be shared by many indices, +but are not particularly index-like themselves (those should go in the :py:mod:`xclim.indices.generic` module). """ from __future__ import annotations @@ -568,22 +568,22 @@ def resample_map( Parameters ---------- - obj: DataArray or Dataset + obj : DataArray or Dataset The xarray object to resample. - dim: str + dim : str Dimension over which to resample. freq : str Resampling frequency along `dim`. func : callable Function to map on each resampled group. - map_blocks: bool or "from_context" + map_blocks : bool or "from_context" If True, the resample().map() call is wrapped inside a `map_blocks`. If False, this does not do anything special. If "from_context", xclim's "resample_map_blocks" option is used. If the object is not using dask, this is set to False. - resample_kwargs: dict, optional + resample_kwargs : dict, optional Other arguments to pass to `obj.resample()`. - map_kwargs: dict, optional + map_kwargs : dict, optional Arguments to pass to `map`. Returns @@ -601,7 +601,8 @@ def resample_map( try: from flox.xarray import rechunk_for_blockwise except ImportError as err: - raise ValueError(f"Using {MAP_BLOCKS}=True requires flox.") from err + msg = f"Using {MAP_BLOCKS}=True requires flox." + raise ValueError(msg) from err # Make labels, a unique integer for each resample group labels = xr.full_like(obj[dim], -1, dtype=np.int32) From 8b717b355ca7be69b54f6c8862bb30a0743a27ba Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 6 Sep 2024 18:37:04 +0000 Subject: [PATCH 17/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xclim/indices/_threshold.py | 1 - 1 file changed, 1 deletion(-) diff --git a/xclim/indices/_threshold.py b/xclim/indices/_threshold.py index 3c5326c20..6ddfef709 100644 --- a/xclim/indices/_threshold.py +++ b/xclim/indices/_threshold.py @@ -29,7 +29,6 @@ spell_length_statistics, threshold_count, ) - from xclim.indices.helpers import resample_map # Frequencies : YS: year start, QS-DEC: seasons starting in december, MS: month start From 561c54aa849433cac777697ef87889e1c60c5b08 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Sep 2024 22:25:21 +0000 Subject: [PATCH 18/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- xclim/indices/helpers.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py index 1ff34a79b..cc1bc1471 100644 --- a/xclim/indices/helpers.py +++ b/xclim/indices/helpers.py @@ -8,9 +8,9 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Callable, Mapping from inspect import stack -from typing import Any, Callable, Literal, cast +from typing import Any, Literal, cast import cf_xarray # noqa: F401, pylint: disable=unused-import import cftime From 54e52343861e648192b100a1caeb82cfabdcc7ba Mon Sep 17 00:00:00 2001 From: Pascal 
Bourgault Date: Tue, 1 Oct 2024 16:29:32 -0400 Subject: [PATCH 19/25] Skip auxiliary coords test --- xclim/core/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xclim/core/utils.py b/xclim/core/utils.py index 8b2f1991b..2d5b3eba1 100644 --- a/xclim/core/utils.py +++ b/xclim/core/utils.py @@ -788,6 +788,7 @@ def split_auxiliary_coordinates( The auxiliary coordinates can be merged back with the dataset with :py:meth:`xarray.Dataset.assign_coords` or :py:meth:`xarray.DataArray.assign_coords`. + >>> # xdoctest: +SKIP >>> clean, aux = split_auxiliary_coordinates(ds) >>> merged = clean.assign_coords(da.coords) >>> merged.identical(ds) # True From ee2e352ffffd9d59ad7ef0f81f9c8b49c6b5fcfa Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Tue, 1 Oct 2024 18:04:37 -0400 Subject: [PATCH 20/25] add tests --- environment.yml | 1 + pyproject.toml | 4 ++-- tests/test_helpers.py | 45 +++++++++++++++++++++++++++++++++++++++++++ tests/test_indices.py | 34 ++++++++++++++++---------------- xclim/core/utils.py | 2 +- 5 files changed, 66 insertions(+), 20 deletions(-) diff --git a/environment.yml b/environment.yml index 74dce6f8b..e2367a7bb 100644 --- a/environment.yml +++ b/environment.yml @@ -11,6 +11,7 @@ dependencies: - click >=8.1 - dask >=2.6.0 - filelock >=3.14.0 + - flox >= 0.9 - jsonpickle >=3.1.0 - numba >=0.54.1 - numpy >=1.23.0 diff --git a/pyproject.toml b/pyproject.toml index d42db0b49..2f5dd1175 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -111,7 +111,7 @@ docs = [ "sphinxcontrib-bibtex", "sphinxcontrib-svg2pdfconverter[Cairosvg]" ] -extras = ["fastnanquantile >=0.0.2", "POT >=0.9.4"] +extras = ["fastnanquantile >=0.0.2", "flox >=0.9", "POT >=0.9.4"] all = ["xclim[dev]", "xclim[docs]", "xclim[extras]"] [project.scripts] @@ -180,7 +180,7 @@ pep621_dev_dependency_groups = ["all", "dev", "docs"] "POT" = "ot" [tool.deptry.per_rule_ignores] -DEP001 = ["SBCK", "flox"] +DEP001 = ["SBCK"] DEP002 = ["bottleneck", "h5netcdf", "pyarrow"] DEP004 = ["matplotlib", "pooch", "pytest", "pytest_socket"] diff --git a/tests/test_helpers.py b/tests/test_helpers.py index 699cb3b6b..cd0395c1f 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -5,8 +5,11 @@ import pytest import xarray as xr +from xclim.core.options import set_options from xclim.core.units import convert_units_to +from xclim.core.utils import uses_dask from xclim.indices import helpers +from xclim.testing.helpers import assert_lazy @pytest.mark.parametrize("method,rtol", [("spencer", 5e3), ("simple", 1e2)]) @@ -132,3 +135,45 @@ def test_cosine_of_solar_zenith_angle(): ] ) np.testing.assert_allclose(cza[:4, :], exp_cza, rtol=1e-3) + + +def _test_function(da, op, dim): + return getattr(da, op)(dim) + + +@pytest.mark.parametrize( + ["in_chunks", "exp_chunks"], [(60, 6 * (2,)), (30, 12 * (1,)), (-1, (12,))] +) +def test_resample_map(tas_series, in_chunks, exp_chunks): + pytest.importorskip("flox") + tas = tas_series(365 * [1]).chunk(time=in_chunks) + with assert_lazy: + out = helpers.resample_map( + tas, "time", "MS", lambda da: da.mean("time"), map_blocks=True + ) + assert out.chunks[0] == exp_chunks + out.load() # Trigger compute to see if it actually works + + +def test_resample_map_dataset(tas_series, pr_series): + pytest.importorskip("flox") + tas = tas_series(3 * 365 * [1], start="2000-01-01").chunk(time=365) + pr = pr_series(3 * 365 * [1], start="2000-01-01").chunk(time=365) + ds = xr.Dataset({"pr": pr, "tas": tas}) + with set_options(resample_map_blocks=True): + with assert_lazy: + out = helpers.resample_map( + ds, 
+ "time", + "YS", + lambda da: da.mean("time"), + ) + assert out.chunks["time"] == (1, 1, 1) + out.load() + + +def test_resample_map_passthrough(tas_series): + tas = tas_series(365 * [1]) + with assert_lazy: + out = helpers.resample_map(tas, "time", "MS", lambda da: da.mean("time")) + assert not uses_dask(out) diff --git a/tests/test_indices.py b/tests/test_indices.py index f33ce3ef2..cc6aabf9b 100644 --- a/tests/test_indices.py +++ b/tests/test_indices.py @@ -29,11 +29,6 @@ from xclim.core.options import set_options from xclim.core.units import convert_units_to, units -try: - import flox -except ImportError: - flox = None - K2C = 273.15 @@ -1435,14 +1430,22 @@ def test_resampling_order(self, tasmax_series, resample_before_rl, expected): a[5:35] = 31 tx = tasmax_series(a + K2C).chunk() - with set_options( - resample_map_blocks=(resample_before_rl and (flox is not None)) - ): - hsf = xci.hot_spell_frequency( - tx, resample_before_rl=resample_before_rl, freq="MS" - ).load() + hsf = xci.hot_spell_frequency( + tx, resample_before_rl=resample_before_rl, freq="MS" + ).load() assert hsf[1] == expected + @pytest.importorskip("flox") + @pytest.mark.parametrize("resample_map", [True, False]) + def test_resampling_map(self, tasmax_series, resample_map): + a = np.zeros(365) + a[5:35] = 31 + tx = tasmax_series(a + K2C).chunk() + + with set_options(resample_map_blocks=resample_map): + hsf = xci.hot_spell_frequency(tx, resample_before_rl=True, freq="MS").load() + assert hsf[1] == 1 + class TestHotSpellMaxLength: @pytest.mark.parametrize( @@ -1717,12 +1720,9 @@ def test_resampling_order(self, pr_series, resample_before_rl, expected): a = np.zeros(365) + 10 a[5:35] = 0 pr = pr_series(a).chunk() - with set_options( - resample_map_blocks=(resample_before_rl and (flox is not None)) - ): - out = xci.maximum_consecutive_dry_days( - pr, freq="ME", resample_before_rl=resample_before_rl - ).load() + out = xci.maximum_consecutive_dry_days( + pr, freq="ME", resample_before_rl=resample_before_rl + ).load() assert out[0] == expected diff --git a/xclim/core/utils.py b/xclim/core/utils.py index 2d5b3eba1..16e24dab7 100644 --- a/xclim/core/utils.py +++ b/xclim/core/utils.py @@ -775,7 +775,7 @@ def split_auxiliary_coordinates( Returns ------- - clean_obj : + clean_obj : DataArray or Dataset Same as `obj` but without any auxiliary coordinate. aux_coords : Dataset The auxiliary coordinates as a dataset. Might be empty. 
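
Usage note: the helper and option exercised by the tests above can also be driven directly. The sketch below is illustrative only (the synthetic data, chunk size and the choice of `hot_spell_frequency` are assumptions, not part of the patches); it requires `dask`, and `flox` for the block-mapped path:

    import numpy as np
    import xarray as xr
    import xclim
    from xclim.indices.helpers import resample_map

    # A dask-backed daily series, chunked in two year-long blocks (illustrative).
    time = xr.cftime_range("2000-01-01", periods=730, freq="D", calendar="noleap")
    tas = xr.DataArray(
        np.random.rand(730) + 270.0,
        coords={"time": time},
        dims="time",
        attrs={"units": "K"},
    ).chunk(time=365)

    # Direct call: the resample(...).map(func) runs inside a single xr.map_blocks,
    # after rechunking so that no resample group straddles a chunk boundary.
    monthly = resample_map(
        tas, "time", "MS", lambda da: da.mean("time"), map_blocks=True
    )

    # Global toggle: indicators using resample_map pick it up from the context.
    with xclim.set_options(resample_map_blocks=True):
        hsf = xclim.indices.hot_spell_frequency(
            tas, resample_before_rl=True, freq="MS"
        )
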
From 6fcd4a93f60bfeffe77c7dc4f6db74ff7a88f7f8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 3 Oct 2024 19:48:12 +0000 Subject: [PATCH 21/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/test_helpers.py | 2 +- xclim/indices/helpers.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_helpers.py b/tests/test_helpers.py index dc253ec5a..8ac4e262f 100644 --- a/tests/test_helpers.py +++ b/tests/test_helpers.py @@ -178,7 +178,7 @@ def test_resample_map_passthrough(tas_series): out = helpers.resample_map(tas, "time", "MS", lambda da: da.mean("time")) assert not uses_dask(out) - + @pytest.mark.parametrize("cftime", [False, True]) def test_make_hourly_temperature(tasmax_series, tasmin_series, cftime): tasmax = tasmax_series(np.array([20]), units="degC", cftime=cftime) diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py index 648b9905e..c3ef99a19 100644 --- a/xclim/indices/helpers.py +++ b/xclim/indices/helpers.py @@ -563,7 +563,7 @@ def _gather_lon(da: xr.DataArray) -> xr.DataArray: ) raise ValueError(msg) from err - + def resample_map( obj: xr.DataArray | xr.Dataset, dim: str, @@ -643,7 +643,7 @@ def _resample_map(obj_chnk, dm, frq, rs_kws, fun, mp_kws): _resample_map, (dim, freq, resample_kwargs, func, map_kwargs), template=template ) - + def _compute_daytime_temperature( hour_after_sunrise: xr.DataArray, tasmin: xr.DataArray, From 50ffcff7c611f3b6351076fa6ee71330974f754d Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Fri, 4 Oct 2024 16:20:38 -0400 Subject: [PATCH 22/25] Import callable --- xclim/indices/helpers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py index c3ef99a19..b70c063d3 100644 --- a/xclim/indices/helpers.py +++ b/xclim/indices/helpers.py @@ -8,7 +8,7 @@ from __future__ import annotations -from collections.abc import Mapping +from collections.abc import Callable, Mapping from datetime import timedelta from inspect import stack from typing import Any, Literal, cast From d5d638a95058300ed9fb142d6e3d87cdf1fafb46 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Fri, 4 Oct 2024 16:38:09 -0400 Subject: [PATCH 23/25] fix test --- tests/test_indices.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_indices.py b/tests/test_indices.py index eb46e4adb..8b5a1d4df 100644 --- a/tests/test_indices.py +++ b/tests/test_indices.py @@ -1461,9 +1461,9 @@ def test_resampling_order(self, tasmax_series, resample_before_rl, expected): ).load() assert hsf[1] == expected - @pytest.importorskip("flox") @pytest.mark.parametrize("resample_map", [True, False]) def test_resampling_map(self, tasmax_series, resample_map): + pytest.importorskip("flox") a = np.zeros(365) a[5:35] = 31 tx = tasmax_series(a + K2C).chunk() From bb47d2b1e01f3f02ad13f961198ca5660b02bbf4 Mon Sep 17 00:00:00 2001 From: Pascal Bourgault Date: Tue, 8 Oct 2024 09:50:51 -0400 Subject: [PATCH 24/25] Resample map for chill portions --- tests/test_atmos.py | 12 ++++++++++-- xclim/indices/_agro.py | 15 +++++++-------- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/tests/test_atmos.py b/tests/test_atmos.py index 80757e5f7..aca53ada5 100644 --- a/tests/test_atmos.py +++ b/tests/test_atmos.py @@ -3,6 +3,7 @@ from __future__ import annotations import numpy as np +import pytest import xarray as xr from xclim import atmos, set_options @@ -642,11 
+643,18 @@ def test_chill_units(atmosds):
     np.testing.assert_allclose(cu.isel(location=0), exp, rtol=1e-03)


-def test_chill_portions(atmosds):
+@pytest.mark.parametrize("use_dask", [True, False])
+def test_chill_portions(atmosds, use_dask):
+    pytest.importorskip("flox")
     tasmax = atmosds.tasmax
     tasmin = atmosds.tasmin
     tas = make_hourly_temperature(tasmin, tasmax)
-    cp = atmos.chill_portions(tas, date_bounds=("09-01", "03-30"), freq="YS-JUL")
+    if use_dask:
+        tas = tas.chunk(time=tas.time.size // 2, location=1)
+
+    with set_options(resample_map_blocks=True):
+        cp = atmos.chill_portions(tas, date_bounds=("09-01", "03-30"), freq="YS-JUL")
+
     assert cp.attrs["units"] == "1"
     assert cp.name == "cp"
     # Although it's 4 years of data, it's 5 seasons starting in July
diff --git a/xclim/indices/_agro.py b/xclim/indices/_agro.py
index d26b0a918..719df0357 100644
--- a/xclim/indices/_agro.py
+++ b/xclim/indices/_agro.py
@@ -16,6 +16,7 @@
     rate2amount,
     to_agg_units,
 )
+from xclim.core.utils import uses_dask
 from xclim.indices._conversion import potential_evapotranspiration
 from xclim.indices._simple import tn_min
 from xclim.indices._threshold import (
@@ -23,7 +24,7 @@
     first_day_temperature_below,
 )
 from xclim.indices.generic import aggregate_between_dates, get_zones
-from xclim.indices.helpers import _gather_lat, day_lengths
+from xclim.indices.helpers import _gather_lat, day_lengths, resample_map
 from xclim.indices.stats import standardized_index

 # Frequencies : YS: year start, QS-DEC: seasons starting in december, MS: month start
@@ -1564,7 +1565,8 @@ def _chill_portion_one_season(tas_K):

 def _apply_chill_portion_one_season(tas_K):
     """Apply the chill portion function to an xarray DataArray."""
-    tas_K = tas_K.chunk(time=-1)
+    if uses_dask(tas_K):
+        tas_K = tas_K.chunk(time=-1)
     return xarray.apply_ufunc(
         _chill_portion_one_season,
         tas_K,
@@ -1627,12 +1629,9 @@ def chill_portions(
     tas_K: xarray.DataArray = select_time(
         convert_units_to(tas, "K"), drop=True, **indexer
     )
-    # TODO: use resample_map once #1848 is merged
-    return (
-        tas_K.resample(time=freq)
-        .map(_apply_chill_portion_one_season)
-        .assign_attrs(units="")
-    )
+    return resample_map(
+        tas_K, "time", freq, _apply_chill_portion_one_season
+    ).assign_attrs(units="")


 @declare_units(tas="[temperature]")

From 57da57571747b43903b9eec83a44c50578f46fd3 Mon Sep 17 00:00:00 2001
From: Pascal Bourgault
Date: Wed, 9 Oct 2024 15:11:10 -0400
Subject: [PATCH 25/25] Fix docstring

---
 xclim/indices/helpers.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/xclim/indices/helpers.py b/xclim/indices/helpers.py
index b70c063d3..25986cefd 100644
--- a/xclim/indices/helpers.py
+++ b/xclim/indices/helpers.py
@@ -573,8 +573,7 @@ def resample_map(
     resample_kwargs: dict | None = None,
     map_kwargs: dict | None = None,
 ) -> xr.DataArray | xr.Dataset:
-    r"""
-    Wraps xarray's resample(...).map() with a :py:func:`xarray.map_blocks`, ensuring the chunking is appropriate using flox.
+    r"""Wraps xarray's resample(...).map() with a :py:func:`xarray.map_blocks`, ensuring the chunking is appropriate using flox.

     Parameters
     ----------
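
A closing usage note on the multi-statistic support from PATCH 12: passing a sequence as ``spell_reducer`` returns one DataArray per statistic, in order. A minimal sketch (the constant series, thresholds and frequency are illustrative assumptions):

    import numpy as np
    import xarray as xr
    from xclim.indices import generic

    # A year of constant sub-freezing temperatures: one spell covering the whole year.
    time = xr.cftime_range("2001-01-01", periods=365, freq="D", calendar="noleap")
    tas = xr.DataArray(
        np.full(365, 270.0), coords={"time": time}, dims="time", attrs={"units": "K"}
    )

    # One output per reducer: here `count` is 1 spell and `longest` is 365 days.
    count, longest = generic.spell_length_statistics(
        tas,
        threshold="0 °C",
        window=5,
        win_reducer="min",
        op="<",
        spell_reducer=["count", "max"],
        freq="YS",
    )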