-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy path_percentile.py
96 lines (76 loc) · 2.76 KB
/
_percentile.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
from typing import List, Sequence
import dask.array as da
import xarray as xr
import numpy as np
from dask.base import tokenize
import dask
from functools import partial
def keep_good_np(xx, where, nodata, out=None):
    """Copy ``xx`` into ``out``, writing ``nodata`` wherever ``where`` is False.

    :param xx: source array
    :param where: boolean mask, True for values to keep from ``xx``
    :param nodata: fill value for masked-out positions
    :param out: optional destination array (same shape/dtype as ``xx``,
        must not alias ``xx``); allocated when not supplied
    :return: ``out``
    """
    if out is None:
        out = np.full_like(xx, nodata)
    else:
        # destination must be compatible and distinct from the source
        assert out.shape == xx.shape and out.dtype == xx.dtype and out is not xx
        out.fill(nodata)
    np.copyto(out, xx, where=where)
    return out
def np_percentile(xx, percentile, nodata, min_valid=3):
    """Compute one percentile along axis 0 of ``xx``, skipping ``nodata`` values.

    Works by fully sorting along axis 0 and then indexing the requested rank
    per pixel, which only gives correct results when ``nodata`` sorts entirely
    to one end of the axis: it must be NaN, or >= the data maximum, or <= the
    data minimum.

    :param xx: array of shape (time, ...); reduced along axis 0
    :param percentile: quantile in the [0.0, 1.0] range
    :param nodata: missing-value marker (NaN or an extreme value)
    :param min_valid: pixels with fewer valid observations than this are set
        to ``nodata`` in the output (default 3, matching previous behavior)
    :return: array of shape ``xx.shape[1:]``
    """
    if np.isnan(nodata):
        # NaN sorts to the high end of the axis
        high = True
        mask = ~np.isnan(xx)
    else:
        # nodata sorts high or low depending on whether it exceeds the data
        high = nodata >= xx.max()
        mask = xx != nodata

    valid_counts = mask.sum(axis=0)
    xx = np.sort(xx, axis=0)

    # Rank of the requested percentile among each pixel's valid samples.
    indices = np.round(percentile * (valid_counts - 1))
    if not high:
        # nodata sorted to the front of the axis: skip past it.
        indices += xx.shape[0] - valid_counts
    # All-nodata pixels get a safe in-bounds index; masked out below anyway.
    indices[valid_counts == 0] = 0

    # Convert per-pixel axis-0 ranks into flat offsets for a single take().
    indices = indices.astype(np.int64).flatten()
    step = xx.size // xx.shape[0]
    indices = step * indices + np.arange(len(indices))
    xx = xx.take(indices).reshape(xx.shape[1:])

    return keep_good_np(xx, (valid_counts >= min_valid), nodata)
def xr_quantile(
    src: xr.DataArray,
    quantiles: Sequence,
    nodata,
) -> xr.Dataset:
    """Calculate quantiles of ``src`` along its first (time) dimension.

    This sort-based approach is far faster than the `numpy`/`xarray`
    nanpercentile functions (~700x per the original author's benchmark).

    :param src: xr.DataArray with the reduction dimension first; values can be
        float or integer with ``nodata`` marking gaps in the data.
        ``nodata`` must be the largest or smallest value in the data, or NaN.
    :param quantiles: a sequence of quantiles in the [0.0, 1.0] range
    :param nodata: the ``nodata`` value
    :return: xr.Dataset with a single "band" variable of shape
        ``(quantile,) + src.shape[1:]``
    """
    xx_data = src.data
    out_dims = ("quantile",) + src.dims[1:]
    # Decide once; every per-quantile result is dask iff the input is.
    is_dask = dask.is_dask_collection(xx_data)
    tk = tokenize(xx_data, quantiles, nodata)

    data = []
    for quantile in quantiles:
        if is_dask:
            # One lazy reduction per quantile; axis 0 is consumed by the kernel.
            name = f"pc_{int(100 * quantile)}"
            yy = da.map_blocks(
                partial(np_percentile, percentile=quantile, nodata=nodata),
                xx_data,
                drop_axis=0,
                meta=np.array([], dtype=src.dtype),
                name=f"{name}-{tk}",
            )
        else:
            yy = np_percentile(xx_data, percentile=quantile, nodata=nodata)
        data.append(yy)

    stack = da.stack if is_dask else np.stack
    data_vars = {"band": (out_dims, stack(data, axis=0))}

    coords = {dim: src.coords[dim] for dim in src.dims[1:]}
    coords["quantile"] = np.array(quantiles)
    return xr.Dataset(data_vars=data_vars, coords=coords, attrs=src.attrs)