Skip to content

Commit

Permalink
Add small selection functions (#746)
Browse files Browse the repository at this point in the history
  • Loading branch information
WenzDaniel authored Aug 11, 2023
1 parent 9b508f7 commit c1d0554
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 9 deletions.
7 changes: 6 additions & 1 deletion strax/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,6 +1205,7 @@ def get_iter(self, run_id: str,
seconds_range=None,
time_within=None,
time_selection='fully_contained',
selection=None,
selection_str=None,
keep_columns=None,
drop_columns=None,
Expand Down Expand Up @@ -1306,6 +1307,7 @@ def get_iter(self, run_id: str,
result.data = strax.apply_selection(
result.data,
selection_str=selection_str,
selection=selection,
keep_columns=keep_columns,
drop_columns=drop_columns,
time_range=time_range,
Expand Down Expand Up @@ -2110,7 +2112,10 @@ def add_method(cls, f):


select_docs = """
:param selection_str: Query string or sequence of strings to apply.
:param selection: Query string, sequence of strings, or simple function to apply.
The function must take a single argument which represents the structure
numpy array of the loaded data.
:param selection_str: Same as selection (deprecated)
:param keep_columns: Array field/dataframe column names to keep.
Useful to reduce amount of data in memory. (You can only specify
either keep or drop column.)
Expand Down
25 changes: 18 additions & 7 deletions strax/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -611,6 +611,7 @@ def iter_chunk_meta(md):

@export
def apply_selection(x,
selection=None,
selection_str=None,
keep_columns=None,
drop_columns=None,
Expand All @@ -619,7 +620,8 @@ def apply_selection(x,
"""Return x after applying selections
:param x: Numpy structured array
:param selection_str: Query string or sequence of strings to apply.
:param selection: Query string, sequence of strings, or simple function to apply.
:param selection_str: Same as selection (deprecated)
:param time_range: (start, stop) range to load, in ns since the epoch
:param keep_columns: Field names of the columns to keep.
:param drop_columns: Field names of the columns to drop.
Expand Down Expand Up @@ -649,12 +651,21 @@ def apply_selection(x,
raise ValueError(f"Unknown time_selection {time_selection}")

if selection_str:
if isinstance(selection_str, (list, tuple)):
selection_str = ' & '.join(f'({x})' for x in selection_str)

mask = numexpr.evaluate(selection_str, local_dict={
fn: x[fn]
for fn in x.dtype.names})
warn('The option "selection_str" is depricated and will be removed in a future release. '
'Please use "selection" instead.')
selection = selection_str

if selection:
if hasattr(selection, '__call__'):
mask = selection(x)
else:
if isinstance(selection, (list, tuple)):
selection = ' & '.join(f'({x})' for x in selection)

mask = numexpr.evaluate(selection, local_dict={
fn: x[fn]
for fn in x.dtype.names})

x = x[mask]

if keep_columns:
Expand Down
24 changes: 23 additions & 1 deletion tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,29 @@ def test_selection_str(d):
mask = (d['data'] > mean_data) & (d['data'] < max_data)
selections_str = [f'data > {mean_data}',
f'data < {max_data}']
selected_data = strax.apply_selection(d, selection_str=selections_str)
selected_data = strax.apply_selection(d, selection=selections_str)
assert np.all(selected_data == d[mask])


@settings(deadline=None)
@given(get_dummy_data(
data_length=(1, 10),
dt=(1, 10),
max_time=(1, 20)))
def test_selection_function(d):
"""
Test selection string. We are going for this example check that
selecting the data based on the data field is the same as if we
were to use a mask NB: data must have some length!
:param d: test-data from get_dummy_data
:return: None
"""
mean_data = np.mean(d['data'])
max_data = np.max(d['data'])
mask = (d['data'] > mean_data) & (d['data'] < max_data)
selections_function = lambda data: (data['data'] > mean_data) & (d['data'] < max_data)
selected_data = strax.apply_selection(d, selection=selections_function)
assert np.all(selected_data == d[mask])


Expand Down

0 comments on commit c1d0554

Please sign in to comment.