Skip to content

Commit

Permalink
Change of default behavior for MAXIMA_THRESHOLD_WINDOW (#49)
Browse files Browse the repository at this point in the history
* Changes how and where the moving threshold is computed

* Minor changes

* Removes unessential line on stage3 config template file
  • Loading branch information
cosimolupo authored Jul 18, 2024
1 parent 65eb0cf commit 491aed3
Show file tree
Hide file tree
Showing 2 changed files with 61 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ NUM_INTERPOLATION_POINTS: 0
# minimum distance between two peaks (s)
MIN_PEAK_DISTANCE: 0.28
# amplitude fraction to set the threshold detecting local maxima
MAXIMA_THRESHOLD_FRACTION: .5
MAXIMA_THRESHOLD_FRACTION: 0.5
# time window to use to set the threshold detecting local maxima (s)
MAXIMA_THRESHOLD_WINDOW: 3
# default value 'None' is meant to set the time window equal to the entire signal length
MAXIMA_THRESHOLD_WINDOW: 'None'
# minimum time the signal must be increasing after a minima candidate (s)
MINIMA_PERSISTENCE: 0.16

Expand Down
105 changes: 58 additions & 47 deletions cobrawap/pipeline/stage03_trigger_detection/scripts/minima.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import quantities as pq
from scipy.signal import find_peaks
import argparse
from pathlib import Path
from utils.io_utils import load_neo, write_neo, save_plot
from utils.neo_utils import remove_annotations, time_slice
from utils.parse import none_or_int, none_or_float, none_or_str
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path


CLI = argparse.ArgumentParser()
Expand All @@ -26,9 +26,9 @@
CLI.add_argument("--minima_persistence", nargs='?', type=float, default=0.200,
help="minimum time minima (s)")
CLI.add_argument("--maxima_threshold_fraction", nargs='?', type=float, default=0.5,
help="amplitude fraction to set the threshold detecting local maxima")
CLI.add_argument("--maxima_threshold_window", nargs='?', type=int, default=2,
help="time window to use to set the threshold detecting local maxima [s]")
help="amplitude fraction (in range [0,1]) to set the threshold for detecting local maxima")
CLI.add_argument("--maxima_threshold_window", nargs='?', type=none_or_float, default=None,
help="time window (s) to set the threshold for detecting local maxima")
CLI.add_argument("--min_peak_distance", nargs='?', type=float, default=0.200,
help="minimum distance between peaks (s)")
CLI.add_argument("--img_dir", nargs='?', type=Path,
Expand All @@ -38,10 +38,10 @@
help='example image filename for channel 0')
CLI.add_argument("--plot_channels", nargs='+', type=none_or_int, default=None,
help="list of channels to plot")
CLI.add_argument("--plot_tstart", nargs='?', type=none_or_float, default=0.,
help="start time in seconds")
CLI.add_argument("--plot_tstop", nargs='?', type=none_or_float, default=10.,
help="stop time in seconds")
CLI.add_argument("--plot_tstart", nargs='?', type=none_or_float, default=0,
help="start time (s)")
CLI.add_argument("--plot_tstop", nargs='?', type=none_or_float, default=10,
help="stop time (s)")

def filter_minima_order(signal, mins, order=1):
filtered_mins = np.array([], dtype=int)
Expand All @@ -56,36 +56,50 @@ def filter_minima_order(signal, mins, order=1):
return filtered_mins


def moving_threshold(signal, window, fraction):
# compute a dynamic threshold function through a sliding window
# on the signal array
strides = np.lib.stride_tricks.sliding_window_view(signal, window)
threshold_func = np.min(strides, axis=1) + fraction*np.ptp(strides, axis=1)
def moving_threshold(signal, maxima_threshold_window, maxima_threshold_fraction):
threshold_signal = []

sampling_rate = signal.sampling_rate.rescale('Hz').magnitude
duration = float(signal.t_stop.rescale('s').magnitude - signal.t_start.rescale('s').magnitude)
if maxima_threshold_window is None or maxima_threshold_window > duration:
maxima_threshold_window = duration
window_frame = int(maxima_threshold_window*sampling_rate)

for channel, channel_signal in enumerate(signal.T):
if np.isnan(channel_signal).any():
threshold_func = np.full(np.shape(signal)[0], np.nan)
else:
# compute a dynamic threshold function through a sliding window
# on the signal array
strides = np.lib.stride_tricks.sliding_window_view(channel_signal, window_frame)
threshold_func = np.min(strides, axis=1) + maxima_threshold_fraction*np.ptp(strides, axis=1)
# add elements at the beginning
threshold_func = np.append(np.ones(window_frame//2)*threshold_func[0], threshold_func)
threshold_func = np.append(threshold_func, np.ones(len(channel_signal)-len(threshold_func))*threshold_func[-1])
threshold_signal.append(threshold_func)

threshold_signal = neo.AnalogSignal(np.array(threshold_signal).T, units=signal.units, sampling_rate=signal.sampling_rate)

# add elements at the beginning
threshold_func = np.append(np.ones(window//2)*threshold_func[0], threshold_func)
threshold_func = np.append(threshold_func, np.ones(len(signal)-len(threshold_func))*threshold_func[-1])
return threshold_func
return threshold_signal


def detect_minima(asig, interpolation_points, maxima_threshold_fraction,
maxima_threshold_window, min_peak_distance, minima_persistence):
def detect_minima(asig, threshold_asig, interpolation_points,
min_peak_distance, minima_persistence):
signal = asig.as_array()
times = asig.times.rescale('s').magnitude
sampling_rate = asig.sampling_rate.rescale('Hz').magnitude
window_frame = int(maxima_threshold_window*sampling_rate)
threshold = threshold_asig.as_array()

min_time_idx = np.array([], dtype=int)
channel_idx = np.array([], dtype=int)

minima_order = int(np.max([minima_persistence*sampling_rate, 1]))
min_distance = np.max([min_peak_distance*sampling_rate, 1])

for channel, channel_signal in enumerate(signal.T):
if np.isnan(channel_signal).any(): continue

threshold_func = moving_threshold(channel_signal, window_frame, maxima_threshold_fraction)
peaks, _ = find_peaks(channel_signal, distance=min_distance, height=threshold_func)
peaks, _ = find_peaks(channel_signal, distance=min_distance, height=threshold.T[channel])
dmins, _ = find_peaks(-channel_signal)#, distance=min_distance)

mins = filter_minima_order(channel_signal, dmins, order=minima_order)
Expand All @@ -101,7 +115,7 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction,

min_time_idx = np.append(min_time_idx, clean_mins)
channel_idx = np.append(channel_idx, np.ones(len(clean_mins), dtype=int)*channel)

# compute local minima times.
if interpolation_points:
# parabolic fit on the right branch of local minima
Expand All @@ -126,12 +140,12 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction,
minimum_times = min_pos/sampling_rate*pq.s
else:
minimum_times = asig.times[min_time_idx]

idx = np.where(minimum_times >= np.max(asig.times))[0]
minimum_times[idx] = np.max(asig.times)
###################################
sort_idx = np.argsort(minimum_times)

# save detected minima as transition
evt = neo.Event(times=minimum_times[sort_idx],
labels=['UP'] * len(minimum_times),
Expand All @@ -146,40 +160,38 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction,

remove_annotations(asig)
evt.annotations.update(asig.annotations)

return evt


def plot_minima(asig, event, channel, maxima_threshold_window,
maxima_threshold_fraction, min_peak_distance):
def plot_minima(asig, event, threshold_asig, channel, min_peak_distance):
signal = asig.as_array().T[channel]
times = asig.times.rescale('s')
sampling_rate = asig.sampling_rate.rescale('Hz').magnitude
window_frame = int(maxima_threshold_window*sampling_rate)
threshold_func = moving_threshold(signal, window_frame, maxima_threshold_fraction)
event = time_slice(event, asig.times[0], asig.times[-1])
threshold = threshold_asig.as_array().T[channel]

peaks, _ = find_peaks(signal, height=threshold_func,
peaks, _ = find_peaks(signal, height=threshold,
distance=np.max([min_peak_distance*sampling_rate, 1]))

# plot figure
sns.set(style='ticks', palette="deep", context="notebook")
fig, ax = plt.subplots()

ax.plot(times, signal, label='signal', color='k')
ax.plot(times, threshold_func, label='moving threshold',
ax.plot(times, threshold, label='moving threshold',
linestyle=':', color='b')

idx_ch = np.where(event.array_annotations['channels'] == channel)[0]

ax.plot(times[peaks], signal[peaks], 'x', color='r', label='detected maxima')
ax.plot(event.times[idx_ch], signal[(event.times[idx_ch]*sampling_rate).astype(int)],
ax.plot(event.times[idx_ch],
signal[((event.times[idx_ch]-asig.times[0])*sampling_rate).astype(int)],
'x', color='g', label='selected minima')

ax.set_title(f'channel {channel}')
ax.set_xlabel('time [s]')
ax.legend()

return ax


Expand All @@ -189,13 +201,13 @@ def plot_minima(asig, event, channel, maxima_threshold_window,
block = load_neo(args.data)
asig = block.segments[0].analogsignals[0]

transition_event = detect_minima(asig,
threshold_asig = moving_threshold(asig, args.maxima_threshold_window, args.maxima_threshold_fraction)

transition_event = detect_minima(asig, threshold_asig,
interpolation_points=args.num_interpolation_points,
maxima_threshold_fraction=args.maxima_threshold_fraction,
maxima_threshold_window=args.maxima_threshold_window,
min_peak_distance=args.min_peak_distance,
minima_persistence=args.minima_persistence)

block.segments[0].events.append(transition_event)

write_neo(args.output, block)
Expand All @@ -204,9 +216,8 @@ def plot_minima(asig, event, channel, maxima_threshold_window,
for channel in args.plot_channels:
plot_minima(asig=time_slice(asig, args.plot_tstart, args.plot_tstop),
event=time_slice(transition_event, args.plot_tstart, args.plot_tstop),
threshold_asig=time_slice(threshold_asig, args.plot_tstart, args.plot_tstop),
channel=int(channel),
maxima_threshold_window = args.maxima_threshold_window,
maxima_threshold_fraction = args.maxima_threshold_fraction,
min_peak_distance = args.min_peak_distance)
min_peak_distance=args.min_peak_distance)
output_path = args.img_dir / args.img_name.replace('_channel0', f'_channel{channel}')
save_plot(output_path)

0 comments on commit 491aed3

Please sign in to comment.