From c1d0554457d201d5496461540e1a7545d21c4940 Mon Sep 17 00:00:00 2001 From: Daniel Wenz <43881800+WenzDaniel@users.noreply.github.com> Date: Fri, 11 Aug 2023 08:14:43 +0200 Subject: [PATCH] Add small selection functions (#746) --- strax/context.py | 7 ++++++- strax/utils.py | 25 ++++++++++++++++++------- tests/test_utils.py | 24 +++++++++++++++++++++++- 3 files changed, 47 insertions(+), 9 deletions(-) diff --git a/strax/context.py b/strax/context.py index 67c97934d..2b449060f 100644 --- a/strax/context.py +++ b/strax/context.py @@ -1205,6 +1205,7 @@ def get_iter(self, run_id: str, seconds_range=None, time_within=None, time_selection='fully_contained', + selection=None, selection_str=None, keep_columns=None, drop_columns=None, @@ -1306,6 +1307,7 @@ def get_iter(self, run_id: str, result.data = strax.apply_selection( result.data, selection_str=selection_str, + selection=selection, keep_columns=keep_columns, drop_columns=drop_columns, time_range=time_range, @@ -2110,7 +2112,10 @@ def add_method(cls, f): select_docs = """ -:param selection_str: Query string or sequence of strings to apply. +:param selection: Query string, sequence of strings, or simple function to apply. + The function must take a single argument which represents the structure + numpy array of the loaded data. +:param selection_str: Same as selection (deprecated) :param keep_columns: Array field/dataframe column names to keep. Useful to reduce amount of data in memory. (You can only specify either keep or drop column.) diff --git a/strax/utils.py b/strax/utils.py index 20f86a88a..ad6f5a266 100644 --- a/strax/utils.py +++ b/strax/utils.py @@ -611,6 +611,7 @@ def iter_chunk_meta(md): @export def apply_selection(x, + selection=None, selection_str=None, keep_columns=None, drop_columns=None, @@ -619,7 +620,8 @@ def apply_selection(x, """Return x after applying selections :param x: Numpy structured array - :param selection_str: Query string or sequence of strings to apply. + :param selection: Query string, sequence of strings, or simple function to apply. + :param selection_str: Same as selection (deprecated) :param time_range: (start, stop) range to load, in ns since the epoch :param keep_columns: Field names of the columns to keep. :param drop_columns: Field names of the columns to drop. @@ -649,12 +651,21 @@ def apply_selection(x, raise ValueError(f"Unknown time_selection {time_selection}") if selection_str: - if isinstance(selection_str, (list, tuple)): - selection_str = ' & '.join(f'({x})' for x in selection_str) - - mask = numexpr.evaluate(selection_str, local_dict={ - fn: x[fn] - for fn in x.dtype.names}) + warn('The option "selection_str" is depricated and will be removed in a future release. ' + 'Please use "selection" instead.') + selection = selection_str + + if selection: + if hasattr(selection, '__call__'): + mask = selection(x) + else: + if isinstance(selection, (list, tuple)): + selection = ' & '.join(f'({x})' for x in selection) + + mask = numexpr.evaluate(selection, local_dict={ + fn: x[fn] + for fn in x.dtype.names}) + x = x[mask] if keep_columns: diff --git a/tests/test_utils.py b/tests/test_utils.py index 52181fcd7..88ae4c05e 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -115,7 +115,29 @@ def test_selection_str(d): mask = (d['data'] > mean_data) & (d['data'] < max_data) selections_str = [f'data > {mean_data}', f'data < {max_data}'] - selected_data = strax.apply_selection(d, selection_str=selections_str) + selected_data = strax.apply_selection(d, selection=selections_str) + assert np.all(selected_data == d[mask]) + + +@settings(deadline=None) +@given(get_dummy_data( + data_length=(1, 10), + dt=(1, 10), + max_time=(1, 20))) +def test_selection_function(d): + """ + Test selection string. We are going for this example check that + selecting the data based on the data field is the same as if we + were to use a mask NB: data must have some length! + + :param d: test-data from get_dummy_data + :return: None + """ + mean_data = np.mean(d['data']) + max_data = np.max(d['data']) + mask = (d['data'] > mean_data) & (d['data'] < max_data) + selections_function = lambda data: (data['data'] > mean_data) & (d['data'] < max_data) + selected_data = strax.apply_selection(d, selection=selections_function) assert np.all(selected_data == d[mask])