From 491aed3ab4bc4b61d69abd5022690773ce972aa6 Mon Sep 17 00:00:00 2001 From: Cosimo Lupo <36234686+cosimolupo@users.noreply.github.com> Date: Thu, 18 Jul 2024 15:22:50 +0200 Subject: [PATCH] Change of default behavior for MAXIMA_THRESHOLD_WINDOW (#49) * Changes how and where the moving threshold is computed * Minor changes * Removes unessential line on stage3 config template file --- .../configs/config_template.yaml | 5 +- .../scripts/minima.py | 105 ++++++++++-------- 2 files changed, 61 insertions(+), 49 deletions(-) diff --git a/cobrawap/pipeline/stage03_trigger_detection/configs/config_template.yaml b/cobrawap/pipeline/stage03_trigger_detection/configs/config_template.yaml index 6c55e5b3..35f23b01 100644 --- a/cobrawap/pipeline/stage03_trigger_detection/configs/config_template.yaml +++ b/cobrawap/pipeline/stage03_trigger_detection/configs/config_template.yaml @@ -64,9 +64,10 @@ NUM_INTERPOLATION_POINTS: 0 # minimum distance between two peaks (s) MIN_PEAK_DISTANCE: 0.28 # amplitude fraction to set the threshold detecting local maxima -MAXIMA_THRESHOLD_FRACTION: .5 +MAXIMA_THRESHOLD_FRACTION: 0.5 # time window to use to set the threshold detecting local maxima (s) -MAXIMA_THRESHOLD_WINDOW: 3 +# default value 'None' is meant to set the time window equal to the entire signal length +MAXIMA_THRESHOLD_WINDOW: 'None' # minimum time the signal must be increasing after a minima candidate (s) MINIMA_PERSISTENCE: 0.16 diff --git a/cobrawap/pipeline/stage03_trigger_detection/scripts/minima.py b/cobrawap/pipeline/stage03_trigger_detection/scripts/minima.py index c959a503..a8672c6b 100644 --- a/cobrawap/pipeline/stage03_trigger_detection/scripts/minima.py +++ b/cobrawap/pipeline/stage03_trigger_detection/scripts/minima.py @@ -8,12 +8,12 @@ import quantities as pq from scipy.signal import find_peaks import argparse +from pathlib import Path from utils.io_utils import load_neo, write_neo, save_plot from utils.neo_utils import remove_annotations, time_slice from utils.parse import none_or_int, none_or_float, none_or_str import seaborn as sns import matplotlib.pyplot as plt -from pathlib import Path CLI = argparse.ArgumentParser() @@ -26,9 +26,9 @@ CLI.add_argument("--minima_persistence", nargs='?', type=float, default=0.200, help="minimum time minima (s)") CLI.add_argument("--maxima_threshold_fraction", nargs='?', type=float, default=0.5, - help="amplitude fraction to set the threshold detecting local maxima") -CLI.add_argument("--maxima_threshold_window", nargs='?', type=int, default=2, - help="time window to use to set the threshold detecting local maxima [s]") + help="amplitude fraction (in range [0,1]) to set the threshold for detecting local maxima") +CLI.add_argument("--maxima_threshold_window", nargs='?', type=none_or_float, default=None, + help="time window (s) to set the threshold for detecting local maxima") CLI.add_argument("--min_peak_distance", nargs='?', type=float, default=0.200, help="minimum distance between peaks (s)") CLI.add_argument("--img_dir", nargs='?', type=Path, @@ -38,10 +38,10 @@ help='example image filename for channel 0') CLI.add_argument("--plot_channels", nargs='+', type=none_or_int, default=None, help="list of channels to plot") -CLI.add_argument("--plot_tstart", nargs='?', type=none_or_float, default=0., - help="start time in seconds") -CLI.add_argument("--plot_tstop", nargs='?', type=none_or_float, default=10., - help="stop time in seconds") +CLI.add_argument("--plot_tstart", nargs='?', type=none_or_float, default=0, + help="start time (s)") +CLI.add_argument("--plot_tstop", nargs='?', type=none_or_float, default=10, + help="stop time (s)") def filter_minima_order(signal, mins, order=1): filtered_mins = np.array([], dtype=int) @@ -56,36 +56,50 @@ def filter_minima_order(signal, mins, order=1): return filtered_mins -def moving_threshold(signal, window, fraction): - # compute a dynamic threshold function through a sliding window - # on the signal array - strides = np.lib.stride_tricks.sliding_window_view(signal, window) - threshold_func = np.min(strides, axis=1) + fraction*np.ptp(strides, axis=1) +def moving_threshold(signal, maxima_threshold_window, maxima_threshold_fraction): + threshold_signal = [] + + sampling_rate = signal.sampling_rate.rescale('Hz').magnitude + duration = float(signal.t_stop.rescale('s').magnitude - signal.t_start.rescale('s').magnitude) + if maxima_threshold_window is None or maxima_threshold_window > duration: + maxima_threshold_window = duration + window_frame = int(maxima_threshold_window*sampling_rate) + + for channel, channel_signal in enumerate(signal.T): + if np.isnan(channel_signal).any(): + threshold_func = np.full(np.shape(signal)[0], np.nan) + else: + # compute a dynamic threshold function through a sliding window + # on the signal array + strides = np.lib.stride_tricks.sliding_window_view(channel_signal, window_frame) + threshold_func = np.min(strides, axis=1) + maxima_threshold_fraction*np.ptp(strides, axis=1) + # add elements at the beginning + threshold_func = np.append(np.ones(window_frame//2)*threshold_func[0], threshold_func) + threshold_func = np.append(threshold_func, np.ones(len(channel_signal)-len(threshold_func))*threshold_func[-1]) + threshold_signal.append(threshold_func) + + threshold_signal = neo.AnalogSignal(np.array(threshold_signal).T, units=signal.units, sampling_rate=signal.sampling_rate) - # add elements at the beginning - threshold_func = np.append(np.ones(window//2)*threshold_func[0], threshold_func) - threshold_func = np.append(threshold_func, np.ones(len(signal)-len(threshold_func))*threshold_func[-1]) - return threshold_func + return threshold_signal -def detect_minima(asig, interpolation_points, maxima_threshold_fraction, - maxima_threshold_window, min_peak_distance, minima_persistence): +def detect_minima(asig, threshold_asig, interpolation_points, + min_peak_distance, minima_persistence): signal = asig.as_array() times = asig.times.rescale('s').magnitude sampling_rate = asig.sampling_rate.rescale('Hz').magnitude - window_frame = int(maxima_threshold_window*sampling_rate) - + threshold = threshold_asig.as_array() + min_time_idx = np.array([], dtype=int) channel_idx = np.array([], dtype=int) minima_order = int(np.max([minima_persistence*sampling_rate, 1])) min_distance = np.max([min_peak_distance*sampling_rate, 1]) - + for channel, channel_signal in enumerate(signal.T): if np.isnan(channel_signal).any(): continue - threshold_func = moving_threshold(channel_signal, window_frame, maxima_threshold_fraction) - peaks, _ = find_peaks(channel_signal, distance=min_distance, height=threshold_func) + peaks, _ = find_peaks(channel_signal, distance=min_distance, height=threshold.T[channel]) dmins, _ = find_peaks(-channel_signal)#, distance=min_distance) mins = filter_minima_order(channel_signal, dmins, order=minima_order) @@ -101,7 +115,7 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction, min_time_idx = np.append(min_time_idx, clean_mins) channel_idx = np.append(channel_idx, np.ones(len(clean_mins), dtype=int)*channel) - + # compute local minima times. if interpolation_points: # parabolic fit on the right branch of local minima @@ -126,12 +140,12 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction, minimum_times = min_pos/sampling_rate*pq.s else: minimum_times = asig.times[min_time_idx] - + idx = np.where(minimum_times >= np.max(asig.times))[0] minimum_times[idx] = np.max(asig.times) ################################### sort_idx = np.argsort(minimum_times) - + # save detected minima as transition evt = neo.Event(times=minimum_times[sort_idx], labels=['UP'] * len(minimum_times), @@ -146,40 +160,38 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction, remove_annotations(asig) evt.annotations.update(asig.annotations) - + return evt -def plot_minima(asig, event, channel, maxima_threshold_window, - maxima_threshold_fraction, min_peak_distance): +def plot_minima(asig, event, threshold_asig, channel, min_peak_distance): signal = asig.as_array().T[channel] times = asig.times.rescale('s') sampling_rate = asig.sampling_rate.rescale('Hz').magnitude - window_frame = int(maxima_threshold_window*sampling_rate) - threshold_func = moving_threshold(signal, window_frame, maxima_threshold_fraction) - event = time_slice(event, asig.times[0], asig.times[-1]) + threshold = threshold_asig.as_array().T[channel] - peaks, _ = find_peaks(signal, height=threshold_func, + peaks, _ = find_peaks(signal, height=threshold, distance=np.max([min_peak_distance*sampling_rate, 1])) - + # plot figure sns.set(style='ticks', palette="deep", context="notebook") fig, ax = plt.subplots() - + ax.plot(times, signal, label='signal', color='k') - ax.plot(times, threshold_func, label='moving threshold', + ax.plot(times, threshold, label='moving threshold', linestyle=':', color='b') idx_ch = np.where(event.array_annotations['channels'] == channel)[0] - + ax.plot(times[peaks], signal[peaks], 'x', color='r', label='detected maxima') - ax.plot(event.times[idx_ch], signal[(event.times[idx_ch]*sampling_rate).astype(int)], + ax.plot(event.times[idx_ch], + signal[((event.times[idx_ch]-asig.times[0])*sampling_rate).astype(int)], 'x', color='g', label='selected minima') ax.set_title(f'channel {channel}') ax.set_xlabel('time [s]') ax.legend() - + return ax @@ -189,13 +201,13 @@ def plot_minima(asig, event, channel, maxima_threshold_window, block = load_neo(args.data) asig = block.segments[0].analogsignals[0] - transition_event = detect_minima(asig, + threshold_asig = moving_threshold(asig, args.maxima_threshold_window, args.maxima_threshold_fraction) + + transition_event = detect_minima(asig, threshold_asig, interpolation_points=args.num_interpolation_points, - maxima_threshold_fraction=args.maxima_threshold_fraction, - maxima_threshold_window=args.maxima_threshold_window, min_peak_distance=args.min_peak_distance, minima_persistence=args.minima_persistence) - + block.segments[0].events.append(transition_event) write_neo(args.output, block) @@ -204,9 +216,8 @@ def plot_minima(asig, event, channel, maxima_threshold_window, for channel in args.plot_channels: plot_minima(asig=time_slice(asig, args.plot_tstart, args.plot_tstop), event=time_slice(transition_event, args.plot_tstart, args.plot_tstop), + threshold_asig=time_slice(threshold_asig, args.plot_tstart, args.plot_tstop), channel=int(channel), - maxima_threshold_window = args.maxima_threshold_window, - maxima_threshold_fraction = args.maxima_threshold_fraction, - min_peak_distance = args.min_peak_distance) + min_peak_distance=args.min_peak_distance) output_path = args.img_dir / args.img_name.replace('_channel0', f'_channel{channel}') save_plot(output_path)