Skip to content

Commit

Permalink
Changes how and where the moving threshold is computed
Browse files Browse the repository at this point in the history
  • Loading branch information
cosimolupo committed Mar 19, 2024
1 parent 6c1b2f4 commit 56cf073
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ NUM_INTERPOLATION_POINTS: 0
# minimum distance between two peaks (s)
MIN_PEAK_DISTANCE: 0.28
# amplitude fraction to set the threshold detecting local maxima
MAXIMA_THRESHOLD_FRACTION: .5
# range of allowed values is [0,1], default value is 0.5
MAXIMA_THRESHOLD_FRACTION: 0.5
# time window to use to set the threshold detecting local maxima (s)
MAXIMA_THRESHOLD_WINDOW: 3
# default value 'None' is meant to set the time window equal to the entire signal length
MAXIMA_THRESHOLD_WINDOW: 'None'
# minimum time the signal must be increasing after a minima candidate (s)
MINIMA_PERSISTENCE: 0.16

Expand Down
81 changes: 46 additions & 35 deletions cobrawap/pipeline/stage03_trigger_detection/scripts/minima.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
CLI.add_argument("--minima_persistence", nargs='?', type=float, default=0.200,
help="minimum time minima (s)")
CLI.add_argument("--maxima_threshold_fraction", nargs='?', type=float, default=0.5,
help="amplitude fraction to set the threshold detecting local maxima")
CLI.add_argument("--maxima_threshold_window", nargs='?', type=int, default=2,
help="time window to use to set the threshold detecting local maxima [s]")
help="amplitude fraction (in range [0,1]) to set the threshold for detecting local maxima")
CLI.add_argument("--maxima_threshold_window", nargs='?', type=none_or_float, default=None,
help="time window (s) to set the threshold for detecting local maxima")
CLI.add_argument("--min_peak_distance", nargs='?', type=float, default=0.200,
help="minimum distance between peaks (s)")
CLI.add_argument("--img_dir", nargs='?', type=Path,
Expand All @@ -39,9 +39,9 @@
CLI.add_argument("--plot_channels", nargs='+', type=none_or_int, default=None,
help="list of channels to plot")
CLI.add_argument("--plot_tstart", nargs='?', type=none_or_float, default=0.,
help="start time in seconds")
help="start time (s)")
CLI.add_argument("--plot_tstop", nargs='?', type=none_or_float, default=10.,
help="stop time in seconds")
help="stop time (s)")

def filter_minima_order(signal, mins, order=1):
filtered_mins = np.array([], dtype=int)
Expand All @@ -56,25 +56,40 @@ def filter_minima_order(signal, mins, order=1):
return filtered_mins


def moving_threshold(signal, window, fraction):
# compute a dynamic threshold function through a sliding window
# on the signal array
strides = np.lib.stride_tricks.sliding_window_view(signal, window)
threshold_func = np.min(strides, axis=1) + fraction*np.ptp(strides, axis=1)
def moving_threshold(signal, maxima_threshold_window, maxima_threshold_fraction):
threshold_signal = []

# add elements at the beginning
threshold_func = np.append(np.ones(window//2)*threshold_func[0], threshold_func)
threshold_func = np.append(threshold_func, np.ones(len(signal)-len(threshold_func))*threshold_func[-1])
return threshold_func
sampling_rate = signal.sampling_rate.rescale('Hz').magnitude
duration = float(signal.t_stop.rescale('s').magnitude - signal.t_start.rescale('s').magnitude)
if maxima_threshold_window is None or maxima_threshold_window > duration:
maxima_threshold_window = duration
window_frame = int(maxima_threshold_window*sampling_rate)

for channel, channel_signal in enumerate(signal.T):
if np.isnan(channel_signal).any():
threshold_func = np.full(np.shape(signal)[0], np.nan)
else:
# compute a dynamic threshold function through a sliding window
# on the signal array
strides = np.lib.stride_tricks.sliding_window_view(channel_signal, window_frame)
threshold_func = np.min(strides, axis=1) + maxima_threshold_fraction*np.ptp(strides, axis=1)
# add elements at the beginning
threshold_func = np.append(np.ones(window_frame//2)*threshold_func[0], threshold_func)
threshold_func = np.append(threshold_func, np.ones(len(channel_signal)-len(threshold_func))*threshold_func[-1])
threshold_signal.append(threshold_func)

threshold_signal = neo.AnalogSignal(np.array(threshold_signal).T, units=signal.units, sampling_rate=signal.sampling_rate)

return threshold_signal


def detect_minima(asig, interpolation_points, maxima_threshold_fraction,
maxima_threshold_window, min_peak_distance, minima_persistence):
def detect_minima(asig, threshold_asig, interpolation_points,
min_peak_distance, minima_persistence):
signal = asig.as_array()
times = asig.times.rescale('s').magnitude
sampling_rate = asig.sampling_rate.rescale('Hz').magnitude
window_frame = int(maxima_threshold_window*sampling_rate)
threshold = threshold_asig.as_array()

min_time_idx = np.array([], dtype=int)
channel_idx = np.array([], dtype=int)

Expand All @@ -84,8 +99,7 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction,
for channel, channel_signal in enumerate(signal.T):
if np.isnan(channel_signal).any(): continue

threshold_func = moving_threshold(channel_signal, window_frame, maxima_threshold_fraction)
peaks, _ = find_peaks(channel_signal, distance=min_distance, height=threshold_func)
peaks, _ = find_peaks(channel_signal, distance=min_distance, height=threshold.T[channel])
dmins, _ = find_peaks(-channel_signal)#, distance=min_distance)

mins = filter_minima_order(channel_signal, dmins, order=minima_order)
Expand Down Expand Up @@ -150,30 +164,28 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction,
return evt


def plot_minima(asig, event, channel, maxima_threshold_window,
maxima_threshold_fraction, min_peak_distance):
def plot_minima(asig, event, threshold_asig, channel, min_peak_distance):
signal = asig.as_array().T[channel]
times = asig.times.rescale('s')
sampling_rate = asig.sampling_rate.rescale('Hz').magnitude
window_frame = int(maxima_threshold_window*sampling_rate)
threshold_func = moving_threshold(signal, window_frame, maxima_threshold_fraction)
event = time_slice(event, asig.times[0], asig.times[-1])
threshold = threshold_asig.as_array().T[channel]

peaks, _ = find_peaks(signal, height=threshold_func,
peaks, _ = find_peaks(signal, height=threshold,
distance=np.max([min_peak_distance*sampling_rate, 1]))

# plot figure
sns.set(style='ticks', palette="deep", context="notebook")
fig, ax = plt.subplots()

ax.plot(times, signal, label='signal', color='k')
ax.plot(times, threshold_func, label='moving threshold',
ax.plot(times, threshold, label='moving threshold',
linestyle=':', color='b')

idx_ch = np.where(event.array_annotations['channels'] == channel)[0]

ax.plot(times[peaks], signal[peaks], 'x', color='r', label='detected maxima')
ax.plot(event.times[idx_ch], signal[(event.times[idx_ch]*sampling_rate).astype(int)],
ax.plot(event.times[idx_ch],
signal[((event.times[idx_ch]-asig.times[0])*sampling_rate).astype(int)],
'x', color='g', label='selected minima')

ax.set_title(f'channel {channel}')
Expand All @@ -189,10 +201,10 @@ def plot_minima(asig, event, channel, maxima_threshold_window,
block = load_neo(args.data)
asig = block.segments[0].analogsignals[0]

transition_event = detect_minima(asig,
threshold_asig = moving_threshold(asig, args.maxima_threshold_window, args.maxima_threshold_fraction)

transition_event = detect_minima(asig, threshold_asig,
interpolation_points=args.num_interpolation_points,
maxima_threshold_fraction=args.maxima_threshold_fraction,
maxima_threshold_window=args.maxima_threshold_window,
min_peak_distance=args.min_peak_distance,
minima_persistence=args.minima_persistence)

Expand All @@ -204,9 +216,8 @@ def plot_minima(asig, event, channel, maxima_threshold_window,
for channel in args.plot_channels:
plot_minima(asig=time_slice(asig, args.plot_tstart, args.plot_tstop),
event=time_slice(transition_event, args.plot_tstart, args.plot_tstop),
threshold_asig=time_slice(threshold_asig, args.plot_tstart, args.plot_tstop),
channel=int(channel),
maxima_threshold_window = args.maxima_threshold_window,
maxima_threshold_fraction = args.maxima_threshold_fraction,
min_peak_distance = args.min_peak_distance)
min_peak_distance=args.min_peak_distance)
output_path = args.img_dir / args.img_name.replace('_channel0', f'_channel{channel}')
save_plot(output_path)
save_plot(output_path)

0 comments on commit 56cf073

Please sign in to comment.