Skip to content

Commit

Permalink
Update get_bounds() to support mappable non-cf axis
Browse files Browse the repository at this point in the history
  • Loading branch information
tomvothecoder committed Oct 7, 2024
1 parent d9a140a commit e59e987
Showing 1 changed file with 25 additions and 12 deletions.
37 changes: 25 additions & 12 deletions xcdat/bounds.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,18 +236,7 @@ def get_bounds(
else:
# Get the obj in the Dataset using the key.
obj = _get_data_var(self._dataset, key=var_key)

# Check if the object is a data variable or a coordinate variable.
# If it is a data variable, derive the axis coordinate variable.
if obj.name in list(self._dataset.data_vars):
coord = get_dim_coords(obj, axis)
elif obj.name in list(self._dataset.coords):
coord = obj

try:
bounds_keys = [coord.attrs["bounds"]]
except KeyError:
bounds_keys = []
bounds_keys = self._get_bounds_from_attr(obj, axis)

if len(bounds_keys) == 0:
raise KeyError(
Expand Down Expand Up @@ -505,8 +494,32 @@ def _get_bounds_keys(self, axis: CFAxisKey) -> List[str]:
except KeyError:
pass

keys_from_attr = self._get_bounds_from_attr(self._dataset, axis)
keys = keys + keys_from_attr

return list(set(keys))

def _get_bounds_from_attr(
self, obj: xr.DataArray | xr.Dataset, axis: CFAxisKey
) -> List[str]:
# Check if the object is a data variable or a coordinate variable.
# If it is a data variable, derive the axis coordinate variable.
bounds_keys = []
coords = get_dim_coords(obj, axis)

if isinstance(coords, xr.DataArray):
bnds_key = coords.attrs.get("bounds")
if bnds_key is not None:
bounds_keys.append(bnds_key)
elif isinstance(coords, xr.Dataset):
for coord in coords.values():
bnds_key = coord.attrs.get("bounds")

if bnds_key is not None:
bounds_keys.append(bnds_key)

return bounds_keys

def _create_time_bounds( # noqa: C901
self,
time: xr.DataArray,
Expand Down

0 comments on commit e59e987

Please sign in to comment.