From e8db39e1aea3c7c5e08490356e91545d709ff712 Mon Sep 17 00:00:00 2001 From: Nicholas Loveday Date: Thu, 14 Nov 2024 15:35:19 +1100 Subject: [PATCH] discretise now takes operators --- src/scores/processing/discretise.py | 98 +++++++++++++++++++++++------ tests/processing/test_discretise.py | 31 +++++++++ 2 files changed, 111 insertions(+), 18 deletions(-) diff --git a/src/scores/processing/discretise.py b/src/scores/processing/discretise.py index 25de0ac0..2970b9c4 100644 --- a/src/scores/processing/discretise.py +++ b/src/scores/processing/discretise.py @@ -2,7 +2,7 @@ import operator from collections.abc import Iterable -from typing import Optional, Union +from typing import Callable, Optional, Union import numpy as np import xarray as xr @@ -20,9 +20,17 @@ # This is because we wish to test within a tolerance. EQUALITY_MODES = {"==": (operator.le), "!=": (operator.gt)} +OPERATOR_MODE_MESSAGE = ( + "or the following operators: operator.ge, operator.gt, operator.le, operator.lt, operator.eq, operator.ne" +) + def comparative_discretise( - data: XarrayLike, comparison: Union[xr.DataArray, float, int], mode: str, *, abs_tolerance: Optional[float] = None + data: XarrayLike, + comparison: Union[xr.DataArray, float, int], + mode: Union[Callable, str], + *, + abs_tolerance: Optional[float] = None, ) -> XarrayLike: """ Converts the values of `data` to 0 or 1 based on how they relate to the specified @@ -32,8 +40,22 @@ def comparative_discretise( data: The data to convert to discrete values. comparison: The values to which to compare `data`. mode: Specifies the required relation of `data` to `thresholds` - for a value to fall in the 'event' category (i.e. assigned to 1). + for a value to fall in the 'event' category (i.e. assigned to 1). The mode + can be a string or a Python operator function from the list below. Allowed modes are: + + - operator.ge. Where values in `data` greater than or equal to the + corresponding threshold are assigned as 1. + - operator.gt. 
Where values in `data` greater than the + corresponding threshold are assigned as 1. + - operator.le. Where values in `data` less than or equal to the + corresponding threshold are assigned as 1. + - operator.lt. Where values in `data` less than the + corresponding threshold are assigned as 1. + - operator.eq values in `data` equal to the corresponding threshold + are assigned as 1 + - operator.ne values in `data` not equal to the corresponding threshold + are assigned as 1. - '>=' values in `data` greater than or equal to the corresponding threshold are assigned as 1. - '>' values in `data` greater than the corresponding threshold @@ -83,10 +105,23 @@ def comparative_discretise( elif mode in EQUALITY_MODES: operator_func = EQUALITY_MODES[mode] discrete_data = operator_func(abs(data - comparison), abs_tolerance).where(notnull_mask) + elif mode is operator.eq: + # Instead of `operator.eq`, we use `operator.le` because we wish to test within a tolerance. + discrete_data = operator.le(abs(data - comparison), abs_tolerance).where(notnull_mask) + elif mode is operator.ne: + # Instead of `operator.ne`, we use `operator.gt` because we wish to test within a tolerance. + discrete_data = operator.gt(abs(data - comparison), abs_tolerance).where(notnull_mask) + elif mode in [operator.lt, operator.le, operator.gt, operator.ge]: + if mode in [operator.lt, operator.ge]: + factor = -1 + else: + factor = 1 + discrete_data = mode(data, comparison + (abs_tolerance * factor)).where(notnull_mask) else: raise ValueError( f"'{mode}' is not a valid mode.
Available modes are: " - f"{sorted(INEQUALITY_MODES) + sorted(EQUALITY_MODES)}" + f"{sorted(INEQUALITY_MODES) + sorted(EQUALITY_MODES)} " + "{OPERATOR_MODE_MESSAGE}" ) discrete_data.attrs["discretisation_tolerance"] = abs_tolerance discrete_data.attrs["discretisation_mode"] = mode @@ -97,7 +132,7 @@ def comparative_discretise( def binary_discretise( data: XarrayLike, thresholds: Optional[FlexibleDimensionTypes], - mode: str, + mode: Union[Callable, str], *, # Force keywords arguments to be keyword-only abs_tolerance: Optional[float] = None, autosqueeze: Optional[bool] = False, @@ -110,9 +145,22 @@ def binary_discretise( data: The data to convert to discrete values. thresholds: Threshold(s) at which to convert the values of `data` to 0 or 1. mode: Specifies the required relation of `data` to `thresholds` - for a value to fall in the 'event' category (i.e. assigned to 1). + for a value to fall in the 'event' category (i.e. assigned to 1). The mode + can be a string or a Python operator function from the list below. Allowed modes are: + - operator.ge. Where values in `data` greater than or equal to the + corresponding threshold are assigned as 1. + - operator.gt. Where values in `data` greater than the + corresponding threshold are assigned as 1. + - operator.le. Where values in `data` less than or equal to the + corresponding threshold are assigned as 1. + - operator.lt. Where values in `data` less than the + corresponding threshold are assigned as 1. + - operator.eq values in `data` equal to the corresponding threshold + are assigned as 1 + - operator.ne values in `data` not equal to the corresponding threshold + are assigned as 1. - '>=' values in `data` greater than or equal to the \ corresponding threshold are assigned as 1. 
- '>' values in `data` greater than the corresponding threshold \ @@ -205,7 +253,7 @@ def proportion_exceeding( def binary_discretise_proportion( data: XarrayLike, thresholds: Iterable, - mode: str, + mode: Union[Callable, str], *, # Force keywords arguments to be keyword-only reduce_dims: Optional[FlexibleDimensionTypes] = None, preserve_dims: Optional[FlexibleDimensionTypes] = None, @@ -224,21 +272,34 @@ def binary_discretise_proportion( thresholds: The proportion of values equal to or exceeding these thresholds will be calculated. mode: Specifies the required relation of `data` to `thresholds` - for a value to fall in the 'event' category (i.e. assigned to 1). + for a value to fall in the 'event' category (i.e. assigned to 1). The mode + can be a string or a Python operator function from the list below. Allowed modes are: + - operator.ge. Where values in `data` greater than or equal to the + corresponding threshold are assigned as 1. + - operator.gt. Where values in `data` greater than the + corresponding threshold are assigned as 1. + - operator.le. Where values in `data` less than or equal to the + corresponding threshold are assigned as 1. + - operator.lt. Where values in `data` less than the + corresponding threshold are assigned as 1. + - operator.eq values in `data` equal to the corresponding threshold + are assigned as 1 + - operator.ne values in `data` not equal to the corresponding threshold + are assigned as 1. - '>=' values in `data` greater than or equal to the - corresponding threshold are assigned as 1. + corresponding threshold are assigned as 1. - '>' values in `data` greater than the corresponding threshold - are assigned as 1. + are assigned as 1. - '<=' values in `data` less than or equal to the corresponding - threshold are assigned as 1. + threshold are assigned as 1. - '<' values in `data` less than the corresponding threshold - are assigned as 1. + are assigned as 1. 
- '==' values in `data` equal to the corresponding threshold - are assigned as 1 + are assigned as 1 - '!=' values in `data` not equal to the corresponding threshold - are assigned as 1. + are assigned as 1. reduce_dims: Dimensions to reduce. preserve_dims: Dimensions to preserve. abs_tolerance: If supplied, values in data that are @@ -257,10 +318,11 @@ def binary_discretise_proportion( satisfy the relationship to `thresholds` as specified by `mode`. Examples: - + >>> import operator + >>> import xarray as xr + >>> from scores.processing import binary_discretise_proportion >>> data = xr.DataArray([0, 0.5, 0.5, 1]) - - >>> _binary_discretise_proportion(data, [0, 0.5, 1], '==') + >>> binary_discretise_proportion(data, [0, 0.5, 1], operator.eq) array([ 0.25, 0.5 , 0.25]) Coordinates: @@ -269,7 +331,7 @@ def binary_discretise_proportion( discretisation_tolerance: 0 discretisation_mode: == - >>> _binary_discretise_proportion(data, [0, 0.5, 1], '>=') + >>> binary_discretise_proportion(data, [0, 0.5, 1], operator.ge) array([ 1. , 0.75, 0.25]) Coordinates: diff --git a/tests/processing/test_discretise.py b/tests/processing/test_discretise.py index 1c62c4aa..c853a416 100644 --- a/tests/processing/test_discretise.py +++ b/tests/processing/test_discretise.py @@ -1,5 +1,7 @@ """Tests for scores.processing.discretise""" +import operator + import numpy as np import pytest import xarray as xr @@ -120,28 +122,40 @@ ################################################################ # 28. mode='>=', tolerance=1e-8 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">=", 1e-8, xtd.EXP_CDIS_GE0), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.ge, 1e-8, xtd.EXP_CDIS_GE0), # 29. mode='>=', tolerance=0 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">=", None, xtd.EXP_CDIS_GE1), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.ge, None, xtd.EXP_CDIS_GE1), # 30. 
mode='>', tolerance=1e-8 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">", 1e-8, xtd.EXP_CDIS_GT0), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.gt, 1e-8, xtd.EXP_CDIS_GT0), # 31. mode='>', tolerance=0 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), ">", None, xtd.EXP_CDIS_GT1), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.gt, None, xtd.EXP_CDIS_GT1), # 32. mode='<=', tolerance=1e-8 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<=", 1e-8, xtd.EXP_CDIS_LE0), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.le, 1e-8, xtd.EXP_CDIS_LE0), # 33. mode='<=', tolerance=0 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<=", None, xtd.EXP_CDIS_LE1), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.le, None, xtd.EXP_CDIS_LE1), # 34. mode='<', tolerance=1e-8 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<", 1e-8, xtd.EXP_CDIS_LT0), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.lt, 1e-8, xtd.EXP_CDIS_LT0), # 35. mode='<', tolerance=0 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "<", None, xtd.EXP_CDIS_LT1), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.lt, None, xtd.EXP_CDIS_LT1), # 36. mode='==', tolerance=1e-8 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "==", 1e-8, xtd.EXP_CDIS_EQ0), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.eq, 1e-8, xtd.EXP_CDIS_EQ0), # 37. mode='==', tolerance=0 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "==", None, xtd.EXP_CDIS_EQ1), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.eq, None, xtd.EXP_CDIS_EQ1), # 38. mode='!=', tolerance=1e-8 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "!=", 1e-8, xtd.EXP_CDIS_NE0), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.ne, 1e-8, xtd.EXP_CDIS_NE0), # 39. 
mode='!=', tolerance=0 (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), "!=", None, xtd.EXP_CDIS_NE1), + (xtd.DATA_5X1_POINT4, xr.DataArray(0.4), operator.ne, None, xtd.EXP_CDIS_NE1), ################################## # 0-D Integer & float comparison # ################################## @@ -288,28 +302,40 @@ def test_comparative_discretise_raises(data, comparison, mode, abs_tolerance, er # SMOKE TESTS FOR ALL MODES: ['<', '>', '>=', '<=', '==', '!='] # 2. mode='>=', tolerance=1e-8 (xtd.DATA_5X1_POINT4, [0.4], ">=", 1e-8, True, xtd.EXP_DIS_GE0), + (xtd.DATA_5X1_POINT4, [0.4], operator.ge, 1e-8, True, xtd.EXP_DIS_GE0), # 3. mode='>=', tolerance=0 (xtd.DATA_5X1_POINT4, [0.4], ">=", None, True, xtd.EXP_DIS_GE1), + (xtd.DATA_5X1_POINT4, [0.4], operator.ge, None, True, xtd.EXP_DIS_GE1), # 4. mode='>', tolerance=1e-8 (xtd.DATA_5X1_POINT4, [0.4], ">", 1e-8, True, xtd.EXP_DIS_GT0), + (xtd.DATA_5X1_POINT4, [0.4], operator.gt, 1e-8, True, xtd.EXP_DIS_GT0), # 5. mode='>', tolerance=0 (xtd.DATA_5X1_POINT4, [0.4], ">", None, True, xtd.EXP_DIS_GT1), + (xtd.DATA_5X1_POINT4, [0.4], operator.gt, None, True, xtd.EXP_DIS_GT1), # 6. mode='<=', tolerance=1e-8 (xtd.DATA_5X1_POINT4, [0.4], "<=", 1e-8, True, xtd.EXP_DIS_LE0), + (xtd.DATA_5X1_POINT4, [0.4], operator.le, 1e-8, True, xtd.EXP_DIS_LE0), # 7. mode='<=', tolerance=0 (xtd.DATA_5X1_POINT4, [0.4], "<=", None, True, xtd.EXP_DIS_LE1), + (xtd.DATA_5X1_POINT4, [0.4], operator.le, None, True, xtd.EXP_DIS_LE1), # 8. mode='<', tolerance=1e-8 (xtd.DATA_5X1_POINT4, [0.4], "<", 1e-8, True, xtd.EXP_DIS_LT0), + (xtd.DATA_5X1_POINT4, [0.4], operator.lt, 1e-8, True, xtd.EXP_DIS_LT0), # 9. mode='<', tolerance=0 (xtd.DATA_5X1_POINT4, [0.4], "<", None, True, xtd.EXP_DIS_LT1), + (xtd.DATA_5X1_POINT4, [0.4], operator.lt, None, True, xtd.EXP_DIS_LT1), # 10. mode='==', tolerance=1e-8 (xtd.DATA_5X1_POINT4, [0.4], "==", 1e-8, True, xtd.EXP_DIS_EQ0), + (xtd.DATA_5X1_POINT4, [0.4], operator.eq, 1e-8, True, xtd.EXP_DIS_EQ0), # 11. 
mode='==', tolerance=0 (xtd.DATA_5X1_POINT4, [0.4], "==", None, True, xtd.EXP_DIS_EQ1), + (xtd.DATA_5X1_POINT4, [0.4], operator.eq, None, True, xtd.EXP_DIS_EQ1), # 12. mode='!=', tolerance=1e-8 (xtd.DATA_5X1_POINT4, [0.4], "!=", 1e-8, True, xtd.EXP_DIS_NE0), + (xtd.DATA_5X1_POINT4, [0.4], operator.ne, 1e-8, True, xtd.EXP_DIS_NE0), # 13. mode='!=', tolerance=0 (xtd.DATA_5X1_POINT4, [0.4], "!=", None, True, xtd.EXP_DIS_NE1), + (xtd.DATA_5X1_POINT4, [0.4], operator.ne, None, True, xtd.EXP_DIS_NE1), # Dataset input # 14. 1-D data, ( @@ -484,15 +510,20 @@ def test_proportion_exceeding(data, thresholds, reduce_dims, preserve_dims, expe # SMOKE TESTS FOR OTHER INEQUALITIES ['<', '>', '>='] # 4. mode='>', tolerance=1e-8 (xtd.DATA_5X1_POINT4, [0.4], ">", None, None, 1e-8, True, xtd.EXP_BDP_4), + (xtd.DATA_5X1_POINT4, [0.4], operator.gt, None, None, 1e-8, True, xtd.EXP_BDP_4), # 5. mode='<=', tolerance=1e-8 (xtd.DATA_5X1_POINT4, [0.4], "<=", None, None, 1e-8, True, xtd.EXP_BDP_5), + (xtd.DATA_5X1_POINT4, [0.4], operator.le, None, None, 1e-8, True, xtd.EXP_BDP_5), # 6. mode='<', tolerance=1e-8 (xtd.DATA_5X1_POINT4, [0.4], "<", None, None, 1e-8, True, xtd.EXP_BDP_6), + (xtd.DATA_5X1_POINT4, [0.4], operator.lt, None, None, 1e-8, True, xtd.EXP_BDP_6), # SMOKE TESTS FOR EQUALITY MODES # 7. mode='==', tolerance=1e-8 (xtd.DATA_5X1_POINT4, [0.4], "==", None, None, 1e-8, True, xtd.EXP_BDP_7), + (xtd.DATA_5X1_POINT4, [0.4], operator.eq, None, None, 1e-8, True, xtd.EXP_BDP_7), # 8. mode='!=', tolerance=1e-8 (xtd.DATA_5X1_POINT4, [0.4], "!=", None, None, 1e-8, True, xtd.EXP_BDP_8), + (xtd.DATA_5X1_POINT4, [0.4], operator.ne, None, None, 1e-8, True, xtd.EXP_BDP_8), # 9. Dataset input ( xr.Dataset({"zero": xtd.DATA_4X1, "one": xtd.DATA_4X1_NAN}),