Skip to content

Commit

Permalink
Update unweighted temporal averages to not require bounds (#579)
Browse files Browse the repository at this point in the history
- Time bounds are required for generating weights for weighted averages

Co-authored-by: tomvothecoder <[email protected]>
  • Loading branch information
tomvothecoder and tomvothecoder authored Jan 2, 2024
1 parent 6238148 commit fb624bd
Showing 1 changed file with 22 additions and 25 deletions.
47 changes: 22 additions & 25 deletions xcdat/temporal.py
Original file line number Diff line number Diff line change
Expand Up @@ -757,21 +757,17 @@ def _averager(
# Preprocess the dataset based on method argument values.
ds = self._preprocess_dataset(ds)

# Get the data variable and the required time axis metadata.
dv = _get_data_var(ds, data_var)
time_bounds = ds.bounds.get_bounds("T", var_key=dv.name)

if self._mode == "average":
dv = self._average(dv, time_bounds)
dv_avg = self._average(ds, data_var)
elif self._mode in ["group_average", "climatology", "departures"]:
dv = self._group_average(dv, time_bounds)
dv_avg = self._group_average(ds, data_var)

# The original time dimension is dropped from the dataset because
# it becomes obsolete after the data variable is averaged. When the
# averaged data variable is added to the dataset, the new time dimension
# and its associated coordinates are also added.
ds = ds.drop_dims(self.dim) # type: ignore
ds[dv.name] = dv
ds[dv_avg.name] = dv_avg

if keep_weights:
ds = self._keep_weights(ds)
Expand Down Expand Up @@ -1075,28 +1071,28 @@ def _drop_leap_days(self, ds: xr.Dataset):
)
return ds

def _average(
self, data_var: xr.DataArray, time_bounds: xr.DataArray
) -> xr.DataArray:
def _average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
"""Averages a data variable with the time dimension removed.
Parameters
----------
data_var : xr.DataArray
The data variable.
time_bounds : xr.DataArray
The time bounds.
ds : xr.Dataset
The dataset.
data_var : str
The key of the data variable.
Returns
-------
xr.DataArray
The averages for a data variable with the time dimension removed.
The data variable averaged with the time dimension removed.
"""
dv = data_var.copy()
dv = _get_data_var(ds, data_var)

with xr.set_options(keep_attrs=True):
if self._weighted:
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)
self._weights = self._get_weights(time_bounds)

dv = dv.weighted(self._weights).mean(dim=self.dim) # type: ignore
else:
dv = dv.mean(dim=self.dim) # type: ignore
Expand All @@ -1105,31 +1101,31 @@ def _average(

return dv

def _group_average(
self, data_var: xr.DataArray, time_bounds: xr.DataArray
) -> xr.DataArray:
def _group_average(self, ds: xr.Dataset, data_var: str) -> xr.DataArray:
"""Averages a data variable by time group.
Parameters
----------
data_var : xr.DataArray
The data variable.
time_bounds : xr.DataArray
The time bounds.
ds : xr.Dataset
The dataset.
data_var : str
The key of the data variable.
Returns
-------
xr.DataArray
The data variable averaged by time group.
"""
dv = data_var.copy()
dv = _get_data_var(ds, data_var)

# Label the time coordinates for grouping weights and the data variable
# values.
self._labeled_time = self._label_time_coords(dv[self.dim])

if self._weighted:
time_bounds = ds.bounds.get_bounds("T", var_key=data_var)
self._weights = self._get_weights(time_bounds)

# Weight the data variable.
dv *= self._weights

Expand All @@ -1145,8 +1141,9 @@ def _group_average(
# included to take into account zero weight for missing data.
with xr.set_options(keep_attrs=True):
dv = self._group_data(dv).sum() / self._group_data(weights).sum()

# Restore the data variable's name.
dv.name = data_var.name
dv.name = data_var
else:
dv = self._group_data(dv).mean()

Expand Down

0 comments on commit fb624bd

Please sign in to comment.