Skip to content

Commit

Permalink
Merge pull request #341 from freemansw1/internal_utils_reorg
Browse files Browse the repository at this point in the history
Reorganization of internal utilities and adding type hints for internal utils
  • Loading branch information
freemansw1 authored Nov 8, 2023
2 parents 85f8f3a + 689e862 commit 5d4759d
Show file tree
Hide file tree
Showing 5 changed files with 278 additions and 134 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ def get_requirements(requirements_filename):
"[email protected]",
],
license="BSD-3-Clause License",
packages=[PACKAGE_NAME, PACKAGE_NAME + ".utils"],
packages=[PACKAGE_NAME, PACKAGE_NAME + ".utils", PACKAGE_NAME + ".utils.internal"],
install_requires=get_requirements("requirements.txt"),
test_requires=["pytest"],
zip_safe=False,
Expand Down
1 change: 1 addition & 0 deletions tobac/utils/internal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .basic import *
203 changes: 70 additions & 133 deletions tobac/utils/internal.py → tobac/utils/internal/basic.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,25 @@
"""Internal tobac utilities
"""
from __future__ import annotations

import numpy as np
import skimage.measure
import xarray as xr
import iris
import iris.cube
import pandas as pd
import warnings
from . import iris_utils
from . import xarray_utils as xr_utils
from typing import Union, Callable

# list of common vertical coordinates to search for in various functions
COMMON_VERT_COORDS: list[str] = [
"z",
"model_level_number",
"altitude",
"geopotential_height",
]


def _warn_auto_coordinate():
Expand All @@ -17,7 +32,7 @@ def _warn_auto_coordinate():
)


def get_label_props_in_dict(labels):
def get_label_props_in_dict(labels: np.array) -> dict:
"""Function to get the label properties into a dictionary format.
Parameters
Expand All @@ -40,7 +55,7 @@ def get_label_props_in_dict(labels):
return region_properties_dict


def get_indices_of_labels_from_reg_prop_dict(region_property_dict):
def get_indices_of_labels_from_reg_prop_dict(region_property_dict: dict) -> tuple[dict]:
"""Function to get the x, y, and z indices (as well as point count) of all labeled regions.
Parameters
----------
Expand Down Expand Up @@ -94,7 +109,7 @@ def get_indices_of_labels_from_reg_prop_dict(region_property_dict):
return [curr_loc_indices, y_indices, x_indices]


def iris_to_xarray(func):
def iris_to_xarray(func: Callable) -> Callable:
"""Decorator that converts all input of a function that is in the form of
Iris cubes into xarray DataArrays and converts all outputs with type
xarray DataArrays back into Iris cubes.
Expand Down Expand Up @@ -164,7 +179,7 @@ def wrapper(*args, **kwargs):
return wrapper


def xarray_to_iris(func):
def xarray_to_iris(func: Callable) -> Callable:
"""Decorator that converts all input of a function that is in the form of
xarray DataArrays into Iris cubes and converts all outputs with type
Iris cubes back into xarray DataArrays.
Expand Down Expand Up @@ -248,7 +263,7 @@ def wrapper(*args, **kwargs):
return wrapper


def irispandas_to_xarray(func):
def irispandas_to_xarray(func: Callable) -> Callable:
"""Decorator that converts all input of a function that is in the form of
Iris cubes/pandas Dataframes into xarray DataArrays/xarray Datasets and
converts all outputs with the type xarray DataArray/xarray Dataset
Expand Down Expand Up @@ -328,7 +343,7 @@ def wrapper(*args, **kwargs):
return wrapper


def xarray_to_irispandas(func):
def xarray_to_irispandas(func: Callable) -> Callable:
"""Decorator that converts all input of a function that is in the form of
DataArrays/xarray Datasets into xarray Iris cubes/pandas Dataframes and
converts all outputs with the type Iris cubes/pandas Dataframes back into
Expand Down Expand Up @@ -431,7 +446,7 @@ def wrapper(*args, **kwargs):
return wrapper


def njit_if_available(func, **kwargs):
def njit_if_available(func: Callable, **kwargs) -> Callable:
"""Decorator to wrap a function with numba.njit if available.
If numba isn't available, it just returns the function.
Expand All @@ -456,12 +471,15 @@ def njit_if_available(func, **kwargs):
return func


def find_vertical_axis_from_coord(variable_cube, vertical_coord=None):
def find_vertical_axis_from_coord(
variable_cube: Union[iris.cube.Cube, xr.DataArray],
vertical_coord: Union[str, None] = None,
) -> str:
"""Function to find the vertical coordinate in the iris cube
Parameters
----------
variable_cube: iris.cube
variable_cube: iris.cube.Cube or xarray.DataArray
Input variable cube, containing a vertical coordinate.
vertical_coord: str
Vertical coordinate name. If None, this function tries to auto-detect.
Expand All @@ -476,69 +494,21 @@ def find_vertical_axis_from_coord(variable_cube, vertical_coord=None):
ValueError
Raised if the vertical coordinate isn't found in the cube.
"""
list_vertical = [
"z",
"model_level_number",
"altitude",
"geopotential_height",
]

if vertical_coord == "auto":
_warn_auto_coordinate()

if isinstance(variable_cube, iris.cube.Cube):
list_coord_names = [coord.name() for coord in variable_cube.coords()]
elif isinstance(variable_cube, xr.Dataset) or isinstance(
variable_cube, xr.DataArray
):
list_coord_names = variable_cube.coords

if vertical_coord is None or vertical_coord == "auto":
# find the intersection
all_vertical_axes = list(set(list_coord_names) & set(list_vertical))
if len(all_vertical_axes) >= 1:
return all_vertical_axes[0]
else:
raise ValueError(
"Cube lacks suitable automatic vertical coordinate (z, model_level_number, altitude, or geopotential_height)"
)
elif vertical_coord in list_coord_names:
return vertical_coord
else:
raise ValueError("Please specify vertical coordinate found in cube")


def find_axis_from_coord(variable_cube, coord_name):
"""Finds the axis number in an iris cube given a coordinate name.
Parameters
----------
variable_cube: iris.cube
Input variable cube
coord_name: str
coordinate to look for
return iris_utils.find_vertical_axis_from_coord(variable_cube, vertical_coord)
if isinstance(variable_cube, xr.Dataset) or isinstance(variable_cube, xr.DataArray):
return xr_utils.find_vertical_axis_from_coord(variable_cube, vertical_coord)

Returns
-------
axis_number: int
the number of the axis of the given coordinate, or None if the coordinate
is not found in the cube or not a dimensional coordinate
"""

list_coord_names = [coord.name() for coord in variable_cube.coords()]
all_matching_axes = list(set(list_coord_names) & set((coord_name,)))
if (
len(all_matching_axes) == 1
and len(variable_cube.coord_dims(all_matching_axes[0])) > 0
):
return variable_cube.coord_dims(all_matching_axes[0])[0]
elif len(all_matching_axes) > 1:
raise ValueError("Too many axes matched.")
else:
return None
raise ValueError("variable_cube must be xr.DataArray or iris.cube.Cube")


def find_dataframe_vertical_coord(variable_dataframe, vertical_coord=None):
def find_dataframe_vertical_coord(
variable_dataframe: pd.DataFrame, vertical_coord: Union[str, None] = None
) -> str:
"""Function to find the vertical coordinate in the iris cube
Parameters
Expand All @@ -563,8 +533,9 @@ def find_dataframe_vertical_coord(variable_dataframe, vertical_coord=None):
_warn_auto_coordinate()

if vertical_coord is None or vertical_coord == "auto":
list_vertical = ["z", "model_level_number", "altitude", "geopotential_height"]
all_vertical_axes = list(set(variable_dataframe.columns) & set(list_vertical))
all_vertical_axes = list(
set(variable_dataframe.columns) & set(COMMON_VERT_COORDS)
)
if len(all_vertical_axes) == 1:
return all_vertical_axes[0]
else:
Expand All @@ -578,7 +549,7 @@ def find_dataframe_vertical_coord(variable_dataframe, vertical_coord=None):


@njit_if_available
def calc_distance_coords(coords_1, coords_2):
def calc_distance_coords(coords_1: np.array, coords_2: np.array) -> float:
"""Function to calculate the distance between cartesian
coordinate set 1 and coordinate set 2.
Parameters
Expand All @@ -605,13 +576,17 @@ def calc_distance_coords(coords_1, coords_2):
return np.sqrt(np.sum(deltas**2))


def find_hdim_axes_3D(field_in, vertical_coord=None, vertical_axis=None):
def find_hdim_axes_3D(
field_in: Union[iris.cube.Cube, xr.DataArray],
vertical_coord: Union[str, None] = None,
vertical_axis: Union[int, None] = None,
) -> tuple[int]:
"""Finds what the hdim axes are given a 3D (including z) or
4D (including z and time) dataset.
Parameters
----------
field_in: iris cube or xarray dataset
field_in: iris cube or xarray dataarray
Input field, can be 3D or 4D
vertical_coord: str
The name of the vertical coord, or None, which will attempt to find
Expand All @@ -626,7 +601,6 @@ def find_hdim_axes_3D(field_in, vertical_coord=None, vertical_axis=None):
The axes for hdim_1 and hdim_2
"""
from iris import cube as iris_cube

if vertical_coord == "auto":
_warn_auto_coordinate()
Expand All @@ -635,91 +609,54 @@ def find_hdim_axes_3D(field_in, vertical_coord=None, vertical_axis=None):
if vertical_coord != "auto":
raise ValueError("Cannot set both vertical_coord and vertical_axis.")

if type(field_in) is iris_cube.Cube:
return find_hdim_axes_3D_iris(field_in, vertical_coord, vertical_axis)
if type(field_in) is iris.cube.Cube:
return iris_utils.find_hdim_axes_3d(field_in, vertical_coord, vertical_axis)
elif type(field_in) is xr.DataArray:
raise NotImplementedError("Xarray find_hdim_axes_3D not implemented")
else:
raise ValueError("Unknown data type: " + type(field_in).__name__)


def find_hdim_axes_3D_iris(field_in, vertical_coord=None, vertical_axis=None):
"""Finds what the hdim axes are given a 3D (including z) or
4D (including z and time) dataset.
def find_axis_from_coord(
variable_arr: Union[iris.cube.Cube, xr.DataArray], coord_name: str
) -> int:
"""Finds the axis number in an xarray or iris cube given a coordinate or dimension name.
Parameters
----------
field_in: iris cube
Input field, can be 3D or 4D
vertical_coord: str or None
The name of the vertical coord, or None, which will attempt to find
the vertical coordinate name
vertical_axis: int or None
The axis number of the vertical coordinate, or None. Note
that only one of vertical_axis or vertical_coord can be set.
variable_arr: iris.cube.Cube or xarray.DataArray
Input variable cube
coord_name: str
coordinate or dimension to look for
Returns
-------
(hdim_1_axis, hdim_2_axis): (int, int)
The axes for hdim_1 and hdim_2
axis_number: int
the number of the axis of the given coordinate, or None if the coordinate
is not found in the variable or not a dimensional coordinate
"""

if vertical_coord == "auto":
_warn_auto_coordinate()

if vertical_coord is not None and vertical_axis is not None:
if vertical_coord != "auto":
raise ValueError("Cannot set both vertical_coord and vertical_axis.")

time_axis = find_axis_from_coord(field_in, "time")
if vertical_axis is not None:
vertical_coord_axis = vertical_axis
vert_coord_found = True
if isinstance(variable_arr, iris.cube.Cube):
return iris_utils.find_axis_from_coord(variable_arr, coord_name)
elif isinstance(variable_arr, xr.DataArray):
raise NotImplementedError(
"xarray version of find_axis_from_coord not implemented."
)
else:
try:
vertical_axis = find_vertical_axis_from_coord(
field_in, vertical_coord=vertical_coord
)
except ValueError:
vert_coord_found = False
else:
vert_coord_found = True
ndim_vertical = field_in.coord_dims(vertical_axis)
if len(ndim_vertical) > 1:
raise ValueError(
"please specify 1 dimensional vertical coordinate."
" Current vertical coordinates: {0}".format(ndim_vertical)
)
if len(ndim_vertical) != 0:
vertical_coord_axis = ndim_vertical[0]
else:
# this means the vertical coordinate is an auxiliary coordinate of some kind.
vert_coord_found = False

if not vert_coord_found:
# if we don't have a vertical coordinate, and we are 3D or lower
# that is okay.
if (field_in.ndim == 3 and time_axis is not None) or field_in.ndim < 3:
vertical_coord_axis = None
else:
raise ValueError("No suitable vertical coordinate found")
# Once we know the vertical coordinate, we can resolve the
# horizontal coordinates

all_axes = np.arange(0, field_in.ndim)
output_vals = tuple(
all_axes[np.logical_not(np.isin(all_axes, [time_axis, vertical_coord_axis]))]
)
return output_vals
raise ValueError("variable_arr must be Iris Cube or Xarray DataArray")


@irispandas_to_xarray
def detect_latlon_coord_name(in_dataset, latitude_name=None, longitude_name=None):
def detect_latlon_coord_name(
in_dataset: Union[xr.DataArray, iris.cube.Cube],
latitude_name: Union[str, None] = None,
longitude_name: Union[str, None] = None,
) -> tuple[str]:
"""Function to detect the name of latitude/longitude coordinates
Parameters
----------
in_dataset: iris.cube.Cube, xarray.Dataset, or xarray.Dataarray
in_dataset: iris.cube.Cube or xarray.DataArray
Input dataset to detect names from
latitude_name: str
The name of the latitude coordinate. If None, tries to auto-detect.
Expand Down
Loading

0 comments on commit 5d4759d

Please sign in to comment.