diff --git a/setup.py b/setup.py index e01547e0..68b13c4d 100644 --- a/setup.py +++ b/setup.py @@ -86,7 +86,7 @@ def get_requirements(requirements_filename): "peter.marinescu@colostate.edu", ], license="BSD-3-Clause License", - packages=[PACKAGE_NAME, PACKAGE_NAME + ".utils"], + packages=[PACKAGE_NAME, PACKAGE_NAME + ".utils", PACKAGE_NAME + ".utils.internal"], install_requires=get_requirements("requirements.txt"), test_requires=["pytest"], zip_safe=False, diff --git a/tobac/utils/internal/__init__.py b/tobac/utils/internal/__init__.py new file mode 100644 index 00000000..da0f2b60 --- /dev/null +++ b/tobac/utils/internal/__init__.py @@ -0,0 +1 @@ +from .basic import * diff --git a/tobac/utils/internal.py b/tobac/utils/internal/basic.py similarity index 78% rename from tobac/utils/internal.py rename to tobac/utils/internal/basic.py index 703cf6b6..97df4dd5 100644 --- a/tobac/utils/internal.py +++ b/tobac/utils/internal/basic.py @@ -1,10 +1,25 @@ """Internal tobac utilities """ +from __future__ import annotations + import numpy as np import skimage.measure import xarray as xr import iris +import iris.cube +import pandas as pd import warnings +from . import iris_utils +from . import xarray_utils as xr_utils +from typing import Union, Callable + +# list of common vertical coordinates to search for in various functions +COMMON_VERT_COORDS: list[str] = [ + "z", + "model_level_number", + "altitude", + "geopotential_height", +] def _warn_auto_coordinate(): @@ -17,7 +32,7 @@ def _warn_auto_coordinate(): ) -def get_label_props_in_dict(labels): +def get_label_props_in_dict(labels: np.array) -> dict: """Function to get the label properties into a dictionary format. 
Parameters @@ -40,7 +55,7 @@ def get_label_props_in_dict(labels): return region_properties_dict -def get_indices_of_labels_from_reg_prop_dict(region_property_dict): +def get_indices_of_labels_from_reg_prop_dict(region_property_dict: dict) -> tuple[dict]: """Function to get the x, y, and z indices (as well as point count) of all labeled regions. Parameters ---------- @@ -94,7 +109,7 @@ def get_indices_of_labels_from_reg_prop_dict(region_property_dict): return [curr_loc_indices, y_indices, x_indices] -def iris_to_xarray(func): +def iris_to_xarray(func: Callable) -> Callable: """Decorator that converts all input of a function that is in the form of Iris cubes into xarray DataArrays and converts all outputs with type xarray DataArrays back into Iris cubes. @@ -164,7 +179,7 @@ def wrapper(*args, **kwargs): return wrapper -def xarray_to_iris(func): +def xarray_to_iris(func: Callable) -> Callable: """Decorator that converts all input of a function that is in the form of xarray DataArrays into Iris cubes and converts all outputs with type Iris cubes back into xarray DataArrays. 
@@ -248,7 +263,7 @@ def wrapper(*args, **kwargs): return wrapper -def irispandas_to_xarray(func): +def irispandas_to_xarray(func: Callable) -> Callable: """Decorator that converts all input of a function that is in the form of Iris cubes/pandas Dataframes into xarray DataArrays/xarray Datasets and converts all outputs with the type xarray DataArray/xarray Dataset @@ -328,7 +343,7 @@ def wrapper(*args, **kwargs): return wrapper -def xarray_to_irispandas(func): +def xarray_to_irispandas(func: Callable) -> Callable: """Decorator that converts all input of a function that is in the form of DataArrays/xarray Datasets into xarray Iris cubes/pandas Dataframes and converts all outputs with the type Iris cubes/pandas Dataframes back into @@ -431,7 +446,7 @@ def wrapper(*args, **kwargs): return wrapper -def njit_if_available(func, **kwargs): +def njit_if_available(func: Callable, **kwargs) -> Callable: """Decorator to wrap a function with numba.njit if available. If numba isn't available, it just returns the function. @@ -456,12 +471,15 @@ def njit_if_available(func, **kwargs): return func -def find_vertical_axis_from_coord(variable_cube, vertical_coord=None): +def find_vertical_axis_from_coord( + variable_cube: Union[iris.cube.Cube, xr.DataArray], + vertical_coord: Union[str, None] = None, +) -> str: """Function to find the vertical coordinate in the iris cube Parameters ---------- - variable_cube: iris.cube + variable_cube: iris.cube.Cube or xarray.DataArray Input variable cube, containing a vertical coordinate. vertical_coord: str Vertical coordinate name. If None, this function tries to auto-detect. @@ -476,69 +494,21 @@ def find_vertical_axis_from_coord(variable_cube, vertical_coord=None): ValueError Raised if the vertical coordinate isn't found in the cube. 
""" - list_vertical = [ - "z", - "model_level_number", - "altitude", - "geopotential_height", - ] if vertical_coord == "auto": _warn_auto_coordinate() if isinstance(variable_cube, iris.cube.Cube): - list_coord_names = [coord.name() for coord in variable_cube.coords()] - elif isinstance(variable_cube, xr.Dataset) or isinstance( - variable_cube, xr.DataArray - ): - list_coord_names = variable_cube.coords - - if vertical_coord is None or vertical_coord == "auto": - # find the intersection - all_vertical_axes = list(set(list_coord_names) & set(list_vertical)) - if len(all_vertical_axes) >= 1: - return all_vertical_axes[0] - else: - raise ValueError( - "Cube lacks suitable automatic vertical coordinate (z, model_level_number, altitude, or geopotential_height)" - ) - elif vertical_coord in list_coord_names: - return vertical_coord - else: - raise ValueError("Please specify vertical coordinate found in cube") - - -def find_axis_from_coord(variable_cube, coord_name): - """Finds the axis number in an iris cube given a coordinate name. 
- - Parameters - ---------- - variable_cube: iris.cube - Input variable cube - coord_name: str - coordinate to look for + return iris_utils.find_vertical_axis_from_coord(variable_cube, vertical_coord) + if isinstance(variable_cube, xr.Dataset) or isinstance(variable_cube, xr.DataArray): + return xr_utils.find_vertical_axis_from_coord(variable_cube, vertical_coord) - Returns - ------- - axis_number: int - the number of the axis of the given coordinate, or None if the coordinate - is not found in the cube or not a dimensional coordinate - """ - - list_coord_names = [coord.name() for coord in variable_cube.coords()] - all_matching_axes = list(set(list_coord_names) & set((coord_name,))) - if ( - len(all_matching_axes) == 1 - and len(variable_cube.coord_dims(all_matching_axes[0])) > 0 - ): - return variable_cube.coord_dims(all_matching_axes[0])[0] - elif len(all_matching_axes) > 1: - raise ValueError("Too many axes matched.") - else: - return None + raise ValueError("variable_cube must be xr.DataArray or iris.cube.Cube") -def find_dataframe_vertical_coord(variable_dataframe, vertical_coord=None): +def find_dataframe_vertical_coord( + variable_dataframe: pd.DataFrame, vertical_coord: Union[str, None] = None +) -> str: """Function to find the vertical coordinate in the iris cube Parameters @@ -563,8 +533,9 @@ def find_dataframe_vertical_coord(variable_dataframe, vertical_coord=None): _warn_auto_coordinate() if vertical_coord is None or vertical_coord == "auto": - list_vertical = ["z", "model_level_number", "altitude", "geopotential_height"] - all_vertical_axes = list(set(variable_dataframe.columns) & set(list_vertical)) + all_vertical_axes = list( + set(variable_dataframe.columns) & set(COMMON_VERT_COORDS) + ) if len(all_vertical_axes) == 1: return all_vertical_axes[0] else: @@ -578,7 +549,7 @@ def find_dataframe_vertical_coord(variable_dataframe, vertical_coord=None): @njit_if_available -def calc_distance_coords(coords_1, coords_2): +def calc_distance_coords(coords_1: 
np.array, coords_2: np.array) -> float: """Function to calculate the distance between cartesian coordinate set 1 and coordinate set 2. Parameters @@ -605,13 +576,17 @@ def calc_distance_coords(coords_1, coords_2): return np.sqrt(np.sum(deltas**2)) -def find_hdim_axes_3D(field_in, vertical_coord=None, vertical_axis=None): +def find_hdim_axes_3D( + field_in: Union[iris.cube.Cube, xr.DataArray], + vertical_coord: Union[str, None] = None, + vertical_axis: Union[int, None] = None, +) -> tuple[int]: """Finds what the hdim axes are given a 3D (including z) or 4D (including z and time) dataset. Parameters ---------- - field_in: iris cube or xarray dataset + field_in: iris cube or xarray dataarray Input field, can be 3D or 4D vertical_coord: str The name of the vertical coord, or None, which will attempt to find @@ -626,7 +601,6 @@ def find_hdim_axes_3D(field_in, vertical_coord=None, vertical_axis=None): The axes for hdim_1 and hdim_2 """ - from iris import cube as iris_cube if vertical_coord == "auto": _warn_auto_coordinate() @@ -635,91 +609,54 @@ def find_hdim_axes_3D(field_in, vertical_coord=None, vertical_axis=None): if vertical_coord != "auto": raise ValueError("Cannot set both vertical_coord and vertical_axis.") - if type(field_in) is iris_cube.Cube: - return find_hdim_axes_3D_iris(field_in, vertical_coord, vertical_axis) + if type(field_in) is iris.cube.Cube: + return iris_utils.find_hdim_axes_3d(field_in, vertical_coord, vertical_axis) elif type(field_in) is xr.DataArray: raise NotImplementedError("Xarray find_hdim_axes_3D not implemented") else: raise ValueError("Unknown data type: " + type(field_in).__name__) -def find_hdim_axes_3D_iris(field_in, vertical_coord=None, vertical_axis=None): - """Finds what the hdim axes are given a 3D (including z) or - 4D (including z and time) dataset. 
+def find_axis_from_coord( + variable_arr: Union[iris.cube.Cube, xr.DataArray], coord_name: str +) -> int: + """Finds the axis number in an xarray or iris cube given a coordinate or dimension name. Parameters ---------- - field_in: iris cube - Input field, can be 3D or 4D - vertical_coord: str or None - The name of the vertical coord, or None, which will attempt to find - the vertical coordinate name - vertical_axis: int or None - The axis number of the vertical coordinate, or None. Note - that only one of vertical_axis or vertical_coord can be set. + variable_arr: iris.cube.Cube or xarray.DataArray + Input variable cube + coord_name: str + coordinate or dimension to look for Returns ------- - (hdim_1_axis, hdim_2_axis): (int, int) - The axes for hdim_1 and hdim_2 + axis_number: int + the number of the axis of the given coordinate, or None if the coordinate + is not found in the variable or not a dimensional coordinate """ - if vertical_coord == "auto": - _warn_auto_coordinate() - - if vertical_coord is not None and vertical_axis is not None: - if vertical_coord != "auto": - raise ValueError("Cannot set both vertical_coord and vertical_axis.") - - time_axis = find_axis_from_coord(field_in, "time") - if vertical_axis is not None: - vertical_coord_axis = vertical_axis - vert_coord_found = True + if isinstance(variable_arr, iris.cube.Cube): + return iris_utils.find_axis_from_coord(variable_arr, coord_name) + elif isinstance(variable_arr, xr.DataArray): + raise NotImplementedError( + "xarray version of find_axis_from_coord not implemented." + ) else: - try: - vertical_axis = find_vertical_axis_from_coord( - field_in, vertical_coord=vertical_coord - ) - except ValueError: - vert_coord_found = False - else: - vert_coord_found = True - ndim_vertical = field_in.coord_dims(vertical_axis) - if len(ndim_vertical) > 1: - raise ValueError( - "please specify 1 dimensional vertical coordinate." 
- " Current vertical coordinates: {0}".format(ndim_vertical) - ) - if len(ndim_vertical) != 0: - vertical_coord_axis = ndim_vertical[0] - else: - # this means the vertical coordinate is an auxiliary coordinate of some kind. - vert_coord_found = False - - if not vert_coord_found: - # if we don't have a vertical coordinate, and we are 3D or lower - # that is okay. - if (field_in.ndim == 3 and time_axis is not None) or field_in.ndim < 3: - vertical_coord_axis = None - else: - raise ValueError("No suitable vertical coordinate found") - # Once we know the vertical coordinate, we can resolve the - # horizontal coordinates - - all_axes = np.arange(0, field_in.ndim) - output_vals = tuple( - all_axes[np.logical_not(np.isin(all_axes, [time_axis, vertical_coord_axis]))] - ) - return output_vals + raise ValueError("variable_arr must be Iris Cube or Xarray DataArray") @irispandas_to_xarray -def detect_latlon_coord_name(in_dataset, latitude_name=None, longitude_name=None): +def detect_latlon_coord_name( + in_dataset: Union[xr.DataArray, iris.cube.Cube], + latitude_name: Union[str, None] = None, + longitude_name: Union[str, None] = None, +) -> tuple[str]: """Function to detect the name of latitude/longitude coordinates Parameters ---------- - in_dataset: iris.cube.Cube, xarray.Dataset, or xarray.Dataarray + in_dataset: iris.cube.Cube or xarray.DataArray Input dataset to detect names from latitude_name: str The name of the latitude coordinate. If None, tries to auto-detect. diff --git a/tobac/utils/internal/iris_utils.py b/tobac/utils/internal/iris_utils.py new file mode 100644 index 00000000..36561799 --- /dev/null +++ b/tobac/utils/internal/iris_utils.py @@ -0,0 +1,156 @@ +"""Internal tobac utilities for iris cubes +The goal will be to, ultimately, remove these when we sunset iris +""" +from __future__ import annotations + +from typing import Union + +import iris +import iris.cube +import numpy as np + +from . 
import basic as tb_utils_gi + + +def find_axis_from_coord( + variable_cube: iris.cube.Cube, coord_name: str +) -> Union[int, None]: + """Finds the axis number in an iris cube given a coordinate name. + + Parameters + ---------- + variable_cube: iris.cube + Input variable cube + coord_name: str + coordinate to look for + + Returns + ------- + axis_number: int + the number of the axis of the given coordinate, or None if the coordinate + is not found in the cube or not a dimensional coordinate + """ + + list_coord_names = [coord.name() for coord in variable_cube.coords()] + all_matching_axes = list(set(list_coord_names) & {coord_name}) + if ( + len(all_matching_axes) == 1 + and len(variable_cube.coord_dims(all_matching_axes[0])) > 0 + ): + return variable_cube.coord_dims(all_matching_axes[0])[0] + if len(all_matching_axes) > 1: + raise ValueError("Too many axes matched.") + + return None + + +def find_vertical_axis_from_coord( + variable_cube: iris.cube.Cube, vertical_coord: Union[str, None] = None +) -> str: + """Function to find the vertical coordinate in the iris cube + + Parameters + ---------- + variable_cube: iris.cube + Input variable cube, containing a vertical coordinate. + vertical_coord: str + Vertical coordinate name. If None, this function tries to auto-detect. + + Returns + ------- + str + the vertical coordinate name + + Raises + ------ + ValueError + Raised if the vertical coordinate isn't found in the cube. 
+ """ + + list_coord_names = [coord.name() for coord in variable_cube.coords()] + + if vertical_coord is None or vertical_coord == "auto": + # find the intersection + all_vertical_axes = list( + set(list_coord_names) & set(tb_utils_gi.COMMON_VERT_COORDS) + ) + if len(all_vertical_axes) >= 1: + return all_vertical_axes[0] + raise ValueError( + "Cube lacks suitable automatic vertical coordinate (z, model_level_number, altitude, " + "or geopotential_height)" + ) + if vertical_coord in list_coord_names: + return vertical_coord + raise ValueError("Please specify vertical coordinate found in cube") + + +def find_hdim_axes_3d( + field_in: iris.cube.Cube, + vertical_coord: Union[str, None] = None, + vertical_axis: Union[int, None] = None, +) -> tuple[int]: + """Finds what the hdim axes are given a 3D (including z) or + 4D (including z and time) dataset. + + Parameters + ---------- + field_in: iris cube + Input field, can be 3D or 4D + vertical_coord: str or None + The name of the vertical coord, or None, which will attempt to find + the vertical coordinate name + vertical_axis: int or None + The axis number of the vertical coordinate, or None. Note + that only one of vertical_axis or vertical_coord can be set. + + Returns + ------- + (hdim_1_axis, hdim_2_axis): (int, int) + The axes for hdim_1 and hdim_2 + """ + + if vertical_coord is not None and vertical_axis is not None: + if vertical_coord != "auto": + raise ValueError("Cannot set both vertical_coord and vertical_axis.") + + time_axis = find_axis_from_coord(field_in, "time") + if vertical_axis is not None: + vertical_coord_axis = vertical_axis + vert_coord_found = True + else: + try: + vertical_axis = find_vertical_axis_from_coord( + field_in, vertical_coord=vertical_coord + ) + except ValueError: + vert_coord_found = False + else: + vert_coord_found = True + ndim_vertical = field_in.coord_dims(vertical_axis) + if len(ndim_vertical) > 1: + raise ValueError( + "please specify 1 dimensional vertical coordinate." 
+ f" Current vertical coordinates: {ndim_vertical}" + ) + if len(ndim_vertical) != 0: + vertical_coord_axis = ndim_vertical[0] + else: + # this means the vertical coordinate is an auxiliary coordinate of some kind. + vert_coord_found = False + + if not vert_coord_found: + # if we don't have a vertical coordinate, and we are 3D or lower + # that is okay. + if (field_in.ndim == 3 and time_axis is not None) or field_in.ndim < 3: + vertical_coord_axis = None + else: + raise ValueError("No suitable vertical coordinate found") + # Once we know the vertical coordinate, we can resolve the + # horizontal coordinates + + all_axes = np.arange(0, field_in.ndim) + output_vals = tuple( + all_axes[np.logical_not(np.isin(all_axes, [time_axis, vertical_coord_axis]))] + ) + return output_vals diff --git a/tobac/utils/internal/xarray_utils.py b/tobac/utils/internal/xarray_utils.py new file mode 100644 index 00000000..6d37dacc --- /dev/null +++ b/tobac/utils/internal/xarray_utils.py @@ -0,0 +1,50 @@ +"""Internal tobac utilities for xarray datasets/dataarrays +""" +from __future__ import annotations + + +from typing import Union +import xarray as xr +from . import basic as tb_utils_gi + + +def find_vertical_axis_from_coord( + variable_cube: xr.DataArray, + vertical_coord: Union[str, None] = None, +) -> str: + """Function to find the vertical coordinate in the xarray DataArray + + Parameters + ---------- + variable_cube: xarray.DataArray + Input variable array, containing a vertical coordinate. + vertical_coord: str + Vertical coordinate name. If None, this function tries to auto-detect. + + Returns + ------- + str + the vertical coordinate name + + Raises + ------ + ValueError + Raised if the vertical coordinate isn't found in the cube. 
+ """ + + list_coord_names = variable_cube.coords + + if vertical_coord is None or vertical_coord == "auto": + # find the intersection + all_vertical_axes = list( + set(list_coord_names) & set(tb_utils_gi.COMMON_VERT_COORDS) + ) + if len(all_vertical_axes) >= 1: + return all_vertical_axes[0] + raise ValueError( + "Cube lacks suitable automatic vertical coordinate (z, model_level_number, " + "altitude, or geopotential_height)" + ) + if vertical_coord in list_coord_names: + return vertical_coord + raise ValueError("Please specify vertical coordinate found in cube")