diff --git a/xcdat/bounds.py b/xcdat/bounds.py index dfbc0aff..61e2bbce 100644 --- a/xcdat/bounds.py +++ b/xcdat/bounds.py @@ -236,18 +236,7 @@ def get_bounds( else: # Get the obj in the Dataset using the key. obj = _get_data_var(self._dataset, key=var_key) - - # Check if the object is a data variable or a coordinate variable. - # If it is a data variable, derive the axis coordinate variable. - if obj.name in list(self._dataset.data_vars): - coord = get_dim_coords(obj, axis) - elif obj.name in list(self._dataset.coords): - coord = obj - - try: - bounds_keys = [coord.attrs["bounds"]] - except KeyError: - bounds_keys = [] + bounds_keys = self._get_bounds_from_attr(obj, axis) if len(bounds_keys) == 0: raise KeyError( @@ -505,8 +494,32 @@ def _get_bounds_keys(self, axis: CFAxisKey) -> List[str]: except KeyError: pass + keys_from_attr = self._get_bounds_from_attr(self._dataset, axis) + keys = keys + keys_from_attr + return list(set(keys)) + def _get_bounds_from_attr( + self, obj: xr.DataArray | xr.Dataset, axis: CFAxisKey + ) -> List[str]: + # Check if the object is a data variable or a coordinate variable. + # If it is a data variable, derive the axis coordinate variable. + bounds_keys = [] + coords = get_dim_coords(obj, axis) + + if isinstance(coords, xr.DataArray): + bnds_key = coords.attrs.get("bounds") + if bnds_key is not None: + bounds_keys.append(bnds_key) + elif isinstance(coords, xr.Dataset): + for coord in coords.values(): + bnds_key = coord.attrs.get("bounds") + + if bnds_key is not None: + bounds_keys.append(bnds_key) + + return bounds_keys + def _create_time_bounds( # noqa: C901 self, time: xr.DataArray,