From 56cf07371f374e94f2ac9f78150aa3f297ee7aaa Mon Sep 17 00:00:00 2001 From: cosimolupo Date: Tue, 19 Mar 2024 11:49:08 +0100 Subject: [PATCH] Changes how and where the moving threshold is computed --- .../configs/config_template.yaml | 6 +- .../scripts/minima.py | 81 +++++++++++-------- 2 files changed, 50 insertions(+), 37 deletions(-) diff --git a/cobrawap/pipeline/stage03_trigger_detection/configs/config_template.yaml b/cobrawap/pipeline/stage03_trigger_detection/configs/config_template.yaml index 20d8fefe..629541d8 100644 --- a/cobrawap/pipeline/stage03_trigger_detection/configs/config_template.yaml +++ b/cobrawap/pipeline/stage03_trigger_detection/configs/config_template.yaml @@ -64,9 +64,11 @@ NUM_INTERPOLATION_POINTS: 0 # minimum distance between two peaks (s) MIN_PEAK_DISTANCE: 0.28 # amplitude fraction to set the threshold detecting local maxima -MAXIMA_THRESHOLD_FRACTION: .5 +# range of allowed values is [0,1], default value is 0.5 +MAXIMA_THRESHOLD_FRACTION: 0.5 # time window to use to set the threshold detecting local maxima (s) -MAXIMA_THRESHOLD_WINDOW: 3 +# default value 'None' is meant to set the time window equal to the entire signal length +MAXIMA_THRESHOLD_WINDOW: 'None' # minimum time the signal must be increasing after a minima candidate (s) MINIMA_PERSISTENCE: 0.16 diff --git a/cobrawap/pipeline/stage03_trigger_detection/scripts/minima.py b/cobrawap/pipeline/stage03_trigger_detection/scripts/minima.py index ee16fc47..69c2e6ba 100644 --- a/cobrawap/pipeline/stage03_trigger_detection/scripts/minima.py +++ b/cobrawap/pipeline/stage03_trigger_detection/scripts/minima.py @@ -26,9 +26,9 @@ CLI.add_argument("--minima_persistence", nargs='?', type=float, default=0.200, help="minimum time minima (s)") CLI.add_argument("--maxima_threshold_fraction", nargs='?', type=float, default=0.5, - help="amplitude fraction to set the threshold detecting local maxima") -CLI.add_argument("--maxima_threshold_window", nargs='?', type=int, default=2, - help="time window to use to set the threshold detecting local maxima [s]") + help="amplitude fraction (in range [0,1]) to set the threshold for detecting local maxima") +CLI.add_argument("--maxima_threshold_window", nargs='?', type=none_or_float, default=None, + help="time window (s) to set the threshold for detecting local maxima") CLI.add_argument("--min_peak_distance", nargs='?', type=float, default=0.200, help="minimum distance between peaks (s)") CLI.add_argument("--img_dir", nargs='?', type=Path, @@ -39,9 +39,9 @@ CLI.add_argument("--plot_channels", nargs='+', type=none_or_int, default=None, help="list of channels to plot") CLI.add_argument("--plot_tstart", nargs='?', type=none_or_float, default=0., - help="start time in seconds") + help="start time (s)") CLI.add_argument("--plot_tstop", nargs='?', type=none_or_float, default=10., - help="stop time in seconds") + help="stop time (s)") def filter_minima_order(signal, mins, order=1): filtered_mins = np.array([], dtype=int) @@ -56,25 +56,40 @@ def filter_minima_order(signal, mins, order=1): return filtered_mins -def moving_threshold(signal, window, fraction): - # compute a dynamic threshold function through a sliding window - # on the signal array - strides = np.lib.stride_tricks.sliding_window_view(signal, window) - threshold_func = np.min(strides, axis=1) + fraction*np.ptp(strides, axis=1) +def moving_threshold(signal, maxima_threshold_window, maxima_threshold_fraction): + threshold_signal = [] - # add elements at the beginning - threshold_func = np.append(np.ones(window//2)*threshold_func[0], threshold_func) - threshold_func = np.append(threshold_func, np.ones(len(signal)-len(threshold_func))*threshold_func[-1]) - return threshold_func + sampling_rate = signal.sampling_rate.rescale('Hz').magnitude + duration = float(signal.t_stop.rescale('s').magnitude - signal.t_start.rescale('s').magnitude) + if maxima_threshold_window is None or maxima_threshold_window > duration: + maxima_threshold_window = duration + window_frame = int(maxima_threshold_window*sampling_rate) + + for channel, channel_signal in enumerate(signal.T): + if np.isnan(channel_signal).any(): + threshold_func = np.full(np.shape(signal)[0], np.nan) + else: + # compute a dynamic threshold function through a sliding window + # on the signal array + strides = np.lib.stride_tricks.sliding_window_view(channel_signal, window_frame) + threshold_func = np.min(strides, axis=1) + maxima_threshold_fraction*np.ptp(strides, axis=1) + # add elements at the beginning + threshold_func = np.append(np.ones(window_frame//2)*threshold_func[0], threshold_func) + threshold_func = np.append(threshold_func, np.ones(len(channel_signal)-len(threshold_func))*threshold_func[-1]) + threshold_signal.append(threshold_func) + + threshold_signal = neo.AnalogSignal(np.array(threshold_signal).T, units=signal.units, sampling_rate=signal.sampling_rate) + + return threshold_signal -def detect_minima(asig, interpolation_points, maxima_threshold_fraction, - maxima_threshold_window, min_peak_distance, minima_persistence): +def detect_minima(asig, threshold_asig, interpolation_points, + min_peak_distance, minima_persistence): signal = asig.as_array() times = asig.times.rescale('s').magnitude sampling_rate = asig.sampling_rate.rescale('Hz').magnitude - window_frame = int(maxima_threshold_window*sampling_rate) - + threshold = threshold_asig.as_array() + min_time_idx = np.array([], dtype=int) channel_idx = np.array([], dtype=int) @@ -84,8 +99,7 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction, for channel, channel_signal in enumerate(signal.T): if np.isnan(channel_signal).any(): continue - threshold_func = moving_threshold(channel_signal, window_frame, maxima_threshold_fraction) - peaks, _ = find_peaks(channel_signal, distance=min_distance, height=threshold_func) + peaks, _ = find_peaks(channel_signal, distance=min_distance, height=threshold.T[channel]) dmins, _ = find_peaks(-channel_signal)#, distance=min_distance) mins = filter_minima_order(channel_signal, dmins, order=minima_order) @@ -150,16 +164,13 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction, return evt -def plot_minima(asig, event, channel, maxima_threshold_window, - maxima_threshold_fraction, min_peak_distance): +def plot_minima(asig, event, threshold_asig, channel, min_peak_distance): signal = asig.as_array().T[channel] times = asig.times.rescale('s') sampling_rate = asig.sampling_rate.rescale('Hz').magnitude - window_frame = int(maxima_threshold_window*sampling_rate) - threshold_func = moving_threshold(signal, window_frame, maxima_threshold_fraction) - event = time_slice(event, asig.times[0], asig.times[-1]) + threshold = threshold_asig.as_array().T[channel] - peaks, _ = find_peaks(signal, height=threshold_func, + peaks, _ = find_peaks(signal, height=threshold, distance=np.max([min_peak_distance*sampling_rate, 1])) # plot figure @@ -167,13 +178,14 @@ def plot_minima(asig, event, channel, maxima_threshold_window, fig, ax = plt.subplots() ax.plot(times, signal, label='signal', color='k') - ax.plot(times, threshold_func, label='moving threshold', + ax.plot(times, threshold, label='moving threshold', linestyle=':', color='b') idx_ch = np.where(event.array_annotations['channels'] == channel)[0] ax.plot(times[peaks], signal[peaks], 'x', color='r', label='detected maxima') - ax.plot(event.times[idx_ch], signal[(event.times[idx_ch]*sampling_rate).astype(int)], + ax.plot(event.times[idx_ch], + signal[((event.times[idx_ch]-asig.times[0])*sampling_rate).astype(int)], 'x', color='g', label='selected minima') ax.set_title(f'channel {channel}') @@ -189,10 +201,10 @@ def plot_minima(asig, event, channel, maxima_threshold_window, block = load_neo(args.data) asig = block.segments[0].analogsignals[0] - transition_event = detect_minima(asig, + threshold_asig = moving_threshold(asig, args.maxima_threshold_window, args.maxima_threshold_fraction) + + transition_event = detect_minima(asig, threshold_asig, interpolation_points=args.num_interpolation_points, - maxima_threshold_fraction=args.maxima_threshold_fraction, - maxima_threshold_window=args.maxima_threshold_window, min_peak_distance=args.min_peak_distance, minima_persistence=args.minima_persistence) @@ -204,9 +216,8 @@ def plot_minima(asig, event, channel, maxima_threshold_window, for channel in args.plot_channels: plot_minima(asig=time_slice(asig, args.plot_tstart, args.plot_tstop), event=time_slice(transition_event, args.plot_tstart, args.plot_tstop), + threshold_asig=time_slice(threshold_asig, args.plot_tstart, args.plot_tstop), channel=int(channel), - maxima_threshold_window = args.maxima_threshold_window, - maxima_threshold_fraction = args.maxima_threshold_fraction, - min_peak_distance = args.min_peak_distance) + min_peak_distance=args.min_peak_distance) output_path = args.img_dir / args.img_name.replace('_channel0', f'_channel{channel}') - save_plot(output_path) \ No newline at end of file + save_plot(output_path)