Skip to content

Commit

Permalink
discretise now takes operators
Browse files Browse the repository at this point in the history
  • Loading branch information
nicholasloveday committed Nov 14, 2024
1 parent 866a2f1 commit e8db39e
Show file tree
Hide file tree
Showing 2 changed files with 111 additions and 18 deletions.
98 changes: 80 additions & 18 deletions src/scores/processing/discretise.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import operator
from collections.abc import Iterable
from typing import Optional, Union
from typing import Callable, Optional, Union

import numpy as np
import xarray as xr
Expand All @@ -20,9 +20,17 @@
# This is because we wish to test within a tolerance.
EQUALITY_MODES = {"==": (operator.le), "!=": (operator.gt)}

OPERATOR_MODE_MESSAGE = (
"or the following operators: operator.ge, operator.gt, operator.le, operator.lt, operator.eq, operator.ne"
)


def comparative_discretise(
data: XarrayLike, comparison: Union[xr.DataArray, float, int], mode: str, *, abs_tolerance: Optional[float] = None
data: XarrayLike,
comparison: Union[xr.DataArray, float, int],
mode: Union[Callable, str],
*,
abs_tolerance: Optional[float] = None,
) -> XarrayLike:
"""
Converts the values of `data` to 0 or 1 based on how they relate to the specified
Expand All @@ -32,8 +40,22 @@ def comparative_discretise(
data: The data to convert to discrete values.
comparison: The values to which to compare `data`.
mode: Specifies the required relation of `data` to `comparison`
for a value to fall in the 'event' category (i.e. assigned to 1).
for a value to fall in the 'event' category (i.e. assigned to 1). The mode
can be a string or a Python operator function from the list below.
Allowed modes are:
- operator.ge. Where values in `data` greater than or equal to the
corresponding threshold are assigned as 1.
- operator.gt. Where values in `data` greater than the
corresponding threshold are assigned as 1.
- operator.le. Where values in `data` less than or equal to the
corresponding threshold are assigned as 1.
- operator.lt. Where values in `data` less than the
corresponding threshold are assigned as 1.
- operator.eq. Where values in `data` equal to the corresponding threshold
are assigned as 1.
- operator.ne. Where values in `data` not equal to the corresponding threshold
are assigned as 1.
- '>=' values in `data` greater than or equal to the
corresponding threshold are assigned as 1.
- '>' values in `data` greater than the corresponding threshold
Expand Down Expand Up @@ -83,10 +105,23 @@ def comparative_discretise(
elif mode in EQUALITY_MODES:
operator_func = EQUALITY_MODES[mode]
discrete_data = operator_func(abs(data - comparison), abs_tolerance).where(notnull_mask)
elif mode is operator.eq:
# Instead of `operator.eq`, we use `operator.le` because we wish to test within a tolerance.
discrete_data = operator.le(abs(data - comparison), abs_tolerance).where(notnull_mask)
elif mode is operator.ne:
# Instead of `operator.ne`, we use `operator.gt` because we wish to test within a tolerance.
discrete_data = operator.gt(abs(data - comparison), abs_tolerance).where(notnull_mask)
elif mode in [operator.lt, operator.le, operator.gt, operator.ge]:
if mode in [operator.lt, operator.ge]:
factor = -1
else:
factor = 1
discrete_data = mode(data, comparison + (abs_tolerance * factor)).where(notnull_mask)
else:
raise ValueError(
f"'{mode}' is not a valid mode. Available modes are: "
f"{sorted(INEQUALITY_MODES) + sorted(EQUALITY_MODES)}"
f"{sorted(INEQUALITY_MODES) + sorted(EQUALITY_MODES)} "
"{OPERATOR_MODE_MESSAGE}"
)
discrete_data.attrs["discretisation_tolerance"] = abs_tolerance
discrete_data.attrs["discretisation_mode"] = mode
Expand All @@ -97,7 +132,7 @@ def comparative_discretise(
def binary_discretise(
data: XarrayLike,
thresholds: Optional[FlexibleDimensionTypes],
mode: str,
mode: Union[Callable, str],
*, # Force keywords arguments to be keyword-only
abs_tolerance: Optional[float] = None,
autosqueeze: Optional[bool] = False,
Expand All @@ -110,9 +145,22 @@ def binary_discretise(
data: The data to convert to discrete values.
thresholds: Threshold(s) at which to convert the values of `data` to 0 or 1.
mode: Specifies the required relation of `data` to `thresholds`
for a value to fall in the 'event' category (i.e. assigned to 1).
for a value to fall in the 'event' category (i.e. assigned to 1). The mode
can be a string or a Python operator function from the list below.
Allowed modes are:
- operator.ge. Where values in `data` greater than or equal to the
corresponding threshold are assigned as 1.
- operator.gt. Where values in `data` greater than the
corresponding threshold are assigned as 1.
- operator.le. Where values in `data` less than or equal to the
corresponding threshold are assigned as 1.
- operator.lt. Where values in `data` less than the
corresponding threshold are assigned as 1.
- operator.eq. Where values in `data` equal to the corresponding threshold
are assigned as 1.
- operator.ne. Where values in `data` not equal to the corresponding threshold
are assigned as 1.
- '>=' values in `data` greater than or equal to the \
corresponding threshold are assigned as 1.
- '>' values in `data` greater than the corresponding threshold \
Expand Down Expand Up @@ -205,7 +253,7 @@ def proportion_exceeding(
def binary_discretise_proportion(
data: XarrayLike,
thresholds: Iterable,
mode: str,
mode: Union[Callable, str],
*, # Force keywords arguments to be keyword-only
reduce_dims: Optional[FlexibleDimensionTypes] = None,
preserve_dims: Optional[FlexibleDimensionTypes] = None,
Expand All @@ -224,21 +272,34 @@ def binary_discretise_proportion(
thresholds: The proportion of values
equal to or exceeding these thresholds will be calculated.
mode: Specifies the required relation of `data` to `thresholds`
for a value to fall in the 'event' category (i.e. assigned to 1).
for a value to fall in the 'event' category (i.e. assigned to 1). The mode
can be a string or a Python operator function from the list below.
Allowed modes are:
- operator.ge. Where values in `data` greater than or equal to the
corresponding threshold are assigned as 1.
- operator.gt. Where values in `data` greater than the
corresponding threshold are assigned as 1.
- operator.le. Where values in `data` less than or equal to the
corresponding threshold are assigned as 1.
- operator.lt. Where values in `data` less than the
corresponding threshold are assigned as 1.
- operator.eq. Where values in `data` equal to the corresponding threshold
are assigned as 1.
- operator.ne. Where values in `data` not equal to the corresponding threshold
are assigned as 1.
- '>=' values in `data` greater than or equal to the
corresponding threshold are assigned as 1.
corresponding threshold are assigned as 1.
- '>' values in `data` greater than the corresponding threshold
are assigned as 1.
are assigned as 1.
- '<=' values in `data` less than or equal to the corresponding
threshold are assigned as 1.
threshold are assigned as 1.
- '<' values in `data` less than the corresponding threshold
are assigned as 1.
are assigned as 1.
- '==' values in `data` equal to the corresponding threshold
are assigned as 1
are assigned as 1
- '!=' values in `data` not equal to the corresponding threshold
are assigned as 1.
are assigned as 1.
reduce_dims: Dimensions to reduce.
preserve_dims: Dimensions to preserve.
abs_tolerance: If supplied, values in data that are
Expand All @@ -257,10 +318,11 @@ def binary_discretise_proportion(
satisfy the relationship to `thresholds` as specified by `mode`.
Examples:
>>> import operator
>>> import xarray as xr
>>> from scores.processing import binary_discretise_proportion
>>> data = xr.DataArray([0, 0.5, 0.5, 1])
>>> _binary_discretise_proportion(data, [0, 0.5, 1], '==')
>>> binary_discretise_proportion(data, [0, 0.5, 1], operator.eq)
<xarray.DataArray (threshold: 3)>
array([ 0.25, 0.5 , 0.25])
Coordinates:
Expand All @@ -269,7 +331,7 @@ def binary_discretise_proportion(
discretisation_tolerance: 0
discretisation_mode: ==
>>> _binary_discretise_proportion(data, [0, 0.5, 1], '>=')
>>> binary_discretise_proportion(data, [0, 0.5, 1], operator.ge)
<xarray.DataArray (threshold: 3)>
array([ 1. , 0.75, 0.25])
Coordinates:
Expand Down
31 changes: 31 additions & 0 deletions tests/processing/test_discretise.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for scores.processing.discretise"""

import operator

import numpy as np
import pytest
import xarray as xr
Expand Down Expand Up @@ -120,28 +122,40 @@
################################################################
# 28. mode='>=', tolerance=1e-8
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">=", 1e-8, xtd.EXP_CDIS_GE0),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.ge, 1e-8, xtd.EXP_CDIS_GE0),
# 29. mode='>=', tolerance=0
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">=", None, xtd.EXP_CDIS_GE1),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.ge, None, xtd.EXP_CDIS_GE1),
# 30. mode='>', tolerance=1e-8
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">", 1e-8, xtd.EXP_CDIS_GT0),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.gt, 1e-8, xtd.EXP_CDIS_GT0),
# 31. mode='>', tolerance=0
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">", None, xtd.EXP_CDIS_GT1),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.gt, None, xtd.EXP_CDIS_GT1),
# 32. mode='<=', tolerance=1e-8
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<=", 1e-8, xtd.EXP_CDIS_LE0),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.le, 1e-8, xtd.EXP_CDIS_LE0),
# 33. mode='<=', tolerance=0
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<=", None, xtd.EXP_CDIS_LE1),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.le, None, xtd.EXP_CDIS_LE1),
# 34. mode='<', tolerance=1e-8
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<", 1e-8, xtd.EXP_CDIS_LT0),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.lt, 1e-8, xtd.EXP_CDIS_LT0),
# 35. mode='<', tolerance=0
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<", None, xtd.EXP_CDIS_LT1),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.lt, None, xtd.EXP_CDIS_LT1),
# 36. mode='==', tolerance=1e-8
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "==", 1e-8, xtd.EXP_CDIS_EQ0),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.eq, 1e-8, xtd.EXP_CDIS_EQ0),
# 37. mode='==', tolerance=0
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "==", None, xtd.EXP_CDIS_EQ1),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.eq, None, xtd.EXP_CDIS_EQ1),
# 38. mode='!=', tolerance=1e-8
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "!=", 1e-8, xtd.EXP_CDIS_NE0),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.ne, 1e-8, xtd.EXP_CDIS_NE0),
# 39. mode='!=', tolerance=0
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "!=", None, xtd.EXP_CDIS_NE1),
(xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.ne, None, xtd.EXP_CDIS_NE1),
##################################
# 0-D Integer & float comparison #
##################################
Expand Down Expand Up @@ -288,28 +302,40 @@ def test_comparative_discretise_raises(data, comparison, mode, abs_tolerance, er
# SMOKE TESTS FOR ALL MODES: ['<', '>', '>=', '<=', '==', '!=']
# 2. mode='>=', tolerance=1e-8
(xtd.DATA_5X1_POINT4, [0.4], ">=", 1e-8, True, xtd.EXP_DIS_GE0),
(xtd.DATA_5X1_POINT4, [0.4], operator.ge, 1e-8, True, xtd.EXP_DIS_GE0),
# 3. mode='>=', tolerance=0
(xtd.DATA_5X1_POINT4, [0.4], ">=", None, True, xtd.EXP_DIS_GE1),
(xtd.DATA_5X1_POINT4, [0.4], operator.ge, None, True, xtd.EXP_DIS_GE1),
# 4. mode='>', tolerance=1e-8
(xtd.DATA_5X1_POINT4, [0.4], ">", 1e-8, True, xtd.EXP_DIS_GT0),
(xtd.DATA_5X1_POINT4, [0.4], operator.gt, 1e-8, True, xtd.EXP_DIS_GT0),
# 5. mode='>', tolerance=0
(xtd.DATA_5X1_POINT4, [0.4], ">", None, True, xtd.EXP_DIS_GT1),
(xtd.DATA_5X1_POINT4, [0.4], operator.gt, None, True, xtd.EXP_DIS_GT1),
# 6. mode='<=', tolerance=1e-8
(xtd.DATA_5X1_POINT4, [0.4], "<=", 1e-8, True, xtd.EXP_DIS_LE0),
(xtd.DATA_5X1_POINT4, [0.4], operator.le, 1e-8, True, xtd.EXP_DIS_LE0),
# 7. mode='<=', tolerance=0
(xtd.DATA_5X1_POINT4, [0.4], "<=", None, True, xtd.EXP_DIS_LE1),
(xtd.DATA_5X1_POINT4, [0.4], operator.le, None, True, xtd.EXP_DIS_LE1),
# 8. mode='<', tolerance=1e-8
(xtd.DATA_5X1_POINT4, [0.4], "<", 1e-8, True, xtd.EXP_DIS_LT0),
(xtd.DATA_5X1_POINT4, [0.4], operator.lt, 1e-8, True, xtd.EXP_DIS_LT0),
# 9. mode='<', tolerance=0
(xtd.DATA_5X1_POINT4, [0.4], "<", None, True, xtd.EXP_DIS_LT1),
(xtd.DATA_5X1_POINT4, [0.4], operator.lt, None, True, xtd.EXP_DIS_LT1),
# 10. mode='==', tolerance=1e-8
(xtd.DATA_5X1_POINT4, [0.4], "==", 1e-8, True, xtd.EXP_DIS_EQ0),
(xtd.DATA_5X1_POINT4, [0.4], operator.eq, 1e-8, True, xtd.EXP_DIS_EQ0),
# 11. mode='==', tolerance=0
(xtd.DATA_5X1_POINT4, [0.4], "==", None, True, xtd.EXP_DIS_EQ1),
(xtd.DATA_5X1_POINT4, [0.4], operator.eq, None, True, xtd.EXP_DIS_EQ1),
# 12. mode='!=', tolerance=1e-8
(xtd.DATA_5X1_POINT4, [0.4], "!=", 1e-8, True, xtd.EXP_DIS_NE0),
(xtd.DATA_5X1_POINT4, [0.4], operator.ne, 1e-8, True, xtd.EXP_DIS_NE0),
# 13. mode='!=', tolerance=0
(xtd.DATA_5X1_POINT4, [0.4], "!=", None, True, xtd.EXP_DIS_NE1),
(xtd.DATA_5X1_POINT4, [0.4], operator.ne, None, True, xtd.EXP_DIS_NE1),
# Dataset input
# 14. 1-D data,
(
Expand Down Expand Up @@ -484,15 +510,20 @@ def test_proportion_exceeding(data, thresholds, reduce_dims, preserve_dims, expe
# SMOKE TESTS FOR OTHER INEQUALITIES ['>', '<=', '<']
# 4. mode='>', tolerance=1e-8
(xtd.DATA_5X1_POINT4, [0.4], ">", None, None, 1e-8, True, xtd.EXP_BDP_4),
(xtd.DATA_5X1_POINT4, [0.4], operator.gt, None, None, 1e-8, True, xtd.EXP_BDP_4),
# 5. mode='<=', tolerance=1e-8
(xtd.DATA_5X1_POINT4, [0.4], "<=", None, None, 1e-8, True, xtd.EXP_BDP_5),
(xtd.DATA_5X1_POINT4, [0.4], operator.le, None, None, 1e-8, True, xtd.EXP_BDP_5),
# 6. mode='<', tolerance=1e-8
(xtd.DATA_5X1_POINT4, [0.4], "<", None, None, 1e-8, True, xtd.EXP_BDP_6),
(xtd.DATA_5X1_POINT4, [0.4], operator.lt, None, None, 1e-8, True, xtd.EXP_BDP_6),
# SMOKE TESTS FOR EQUALITY MODES
# 7. mode='==', tolerance=1e-8
(xtd.DATA_5X1_POINT4, [0.4], "==", None, None, 1e-8, True, xtd.EXP_BDP_7),
(xtd.DATA_5X1_POINT4, [0.4], operator.eq, None, None, 1e-8, True, xtd.EXP_BDP_7),
# 8. mode='!=', tolerance=1e-8
(xtd.DATA_5X1_POINT4, [0.4], "!=", None, None, 1e-8, True, xtd.EXP_BDP_8),
(xtd.DATA_5X1_POINT4, [0.4], operator.ne, None, None, 1e-8, True, xtd.EXP_BDP_8),
# 9. Dataset input
(
xr.Dataset({"zero": xtd.DATA_4X1, "one": xtd.DATA_4X1_NAN}),
Expand Down

0 comments on commit e8db39e

Please sign in to comment.