Skip to content

Commit

Permalink
weighted: small improvements (#4818)
Browse files Browse the repository at this point in the history
* weighted: small improvements

* use T_DataWithCoords
  • Loading branch information
mathause authored Jan 27, 2021
1 parent a4bb7e1 commit 9fea799
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 32 deletions.
11 changes: 10 additions & 1 deletion xarray/core/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from html import escape
from textwrap import dedent
from typing import (
TYPE_CHECKING,
Any,
Callable,
Dict,
Expand Down Expand Up @@ -32,6 +33,12 @@
ALL_DIMS = ...


if TYPE_CHECKING:
from .dataarray import DataArray
from .weighted import Weighted

T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")

C = TypeVar("C")
T = TypeVar("T")

Expand Down Expand Up @@ -772,7 +779,9 @@ def groupby_bins(
},
)

def weighted(self, weights):
def weighted(
self: T_DataWithCoords, weights: "DataArray"
) -> "Weighted[T_DataWithCoords]":
"""
Weighted operations.
Expand Down
49 changes: 18 additions & 31 deletions xarray/core/weighted.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import TYPE_CHECKING, Hashable, Iterable, Optional, Union, overload
from typing import TYPE_CHECKING, Generic, Hashable, Iterable, Optional, TypeVar, Union

from . import duck_array_ops
from .computation import dot
from .options import _get_keep_attrs
from .pycompat import is_duck_dask_array

if TYPE_CHECKING:
from .common import DataWithCoords # noqa: F401
from .dataarray import DataArray, Dataset

T_DataWithCoords = TypeVar("T_DataWithCoords", bound="DataWithCoords")


_WEIGHTED_REDUCE_DOCSTRING_TEMPLATE = """
Reduce this {cls}'s data by a weighted ``{fcn}`` along some dimension(s).
Expand Down Expand Up @@ -56,7 +59,7 @@
"""


class Weighted:
class Weighted(Generic[T_DataWithCoords]):
"""An object that implements weighted operations.
You should create a Weighted object by using the ``DataArray.weighted`` or
Expand All @@ -70,15 +73,7 @@ class Weighted:

__slots__ = ("obj", "weights")

@overload
def __init__(self, obj: "DataArray", weights: "DataArray") -> None:
...

@overload
def __init__(self, obj: "Dataset", weights: "DataArray") -> None:
...

def __init__(self, obj, weights):
def __init__(self, obj: T_DataWithCoords, weights: "DataArray"):
"""
Create a Weighted object
Expand Down Expand Up @@ -121,8 +116,8 @@ def _weight_check(w):
else:
_weight_check(weights.data)

self.obj = obj
self.weights = weights
self.obj: T_DataWithCoords = obj
self.weights: "DataArray" = weights

@staticmethod
def _reduce(
Expand All @@ -146,7 +141,6 @@ def _reduce(

# `dot` does not broadcast arrays, so this avoids creating a large
# DataArray (if `weights` has additional dimensions)
# maybe add fasttrack (`(da * weights).sum(dims=dim, skipna=skipna)`)
return dot(da, weights, dims=dim)

def _sum_of_weights(
Expand Down Expand Up @@ -203,7 +197,7 @@ def sum_of_weights(
self,
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
keep_attrs: Optional[bool] = None,
) -> Union["DataArray", "Dataset"]:
) -> T_DataWithCoords:

return self._implementation(
self._sum_of_weights, dim=dim, keep_attrs=keep_attrs
Expand All @@ -214,7 +208,7 @@ def sum(
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
keep_attrs: Optional[bool] = None,
) -> Union["DataArray", "Dataset"]:
) -> T_DataWithCoords:

return self._implementation(
self._weighted_sum, dim=dim, skipna=skipna, keep_attrs=keep_attrs
Expand All @@ -225,7 +219,7 @@ def mean(
dim: Optional[Union[Hashable, Iterable[Hashable]]] = None,
skipna: Optional[bool] = None,
keep_attrs: Optional[bool] = None,
) -> Union["DataArray", "Dataset"]:
) -> T_DataWithCoords:

return self._implementation(
self._weighted_mean, dim=dim, skipna=skipna, keep_attrs=keep_attrs
Expand All @@ -239,22 +233,15 @@ def __repr__(self):
return f"{klass} with weights along dimensions: {weight_dims}"


class DataArrayWeighted(Weighted):
def _implementation(self, func, dim, **kwargs):

keep_attrs = kwargs.pop("keep_attrs")
if keep_attrs is None:
keep_attrs = _get_keep_attrs(default=False)

weighted = func(self.obj, dim=dim, **kwargs)

if keep_attrs:
weighted.attrs = self.obj.attrs
class DataArrayWeighted(Weighted["DataArray"]):
def _implementation(self, func, dim, **kwargs) -> "DataArray":

return weighted
dataset = self.obj._to_temp_dataset()
dataset = dataset.map(func, dim=dim, **kwargs)
return self.obj._from_temp_dataset(dataset)


class DatasetWeighted(Weighted):
class DatasetWeighted(Weighted["Dataset"]):
def _implementation(self, func, dim, **kwargs) -> "Dataset":

return self.obj.map(func, dim=dim, **kwargs)
Expand Down

0 comments on commit 9fea799

Please sign in to comment.