Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change of default behavior for MAXIMA_THRESHOLD_WINDOW #49

Merged
merged 4 commits into from
Jul 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,10 @@ NUM_INTERPOLATION_POINTS: 0
# minimum distance between two peaks (s)
MIN_PEAK_DISTANCE: 0.28
# amplitude fraction to set the threshold detecting local maxima
MAXIMA_THRESHOLD_FRACTION: .5
MAXIMA_THRESHOLD_FRACTION: 0.5
# time window to use to set the threshold detecting local maxima (s)
MAXIMA_THRESHOLD_WINDOW: 3
# default value 'None' is meant to set the time window equal to the entire signal length
MAXIMA_THRESHOLD_WINDOW: 'None'
# minimum time the signal must be increasing after a minima candidate (s)
MINIMA_PERSISTENCE: 0.16

Expand Down
105 changes: 58 additions & 47 deletions cobrawap/pipeline/stage03_trigger_detection/scripts/minima.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,12 @@
import quantities as pq
from scipy.signal import find_peaks
import argparse
from pathlib import Path
from utils.io_utils import load_neo, write_neo, save_plot
from utils.neo_utils import remove_annotations, time_slice
from utils.parse import none_or_int, none_or_float, none_or_str
import seaborn as sns
import matplotlib.pyplot as plt
from pathlib import Path


CLI = argparse.ArgumentParser()
Expand All @@ -26,9 +26,9 @@
CLI.add_argument("--minima_persistence", nargs='?', type=float, default=0.200,
help="minimum time minima (s)")
CLI.add_argument("--maxima_threshold_fraction", nargs='?', type=float, default=0.5,
help="amplitude fraction to set the threshold detecting local maxima")
CLI.add_argument("--maxima_threshold_window", nargs='?', type=int, default=2,
help="time window to use to set the threshold detecting local maxima [s]")
help="amplitude fraction (in range [0,1]) to set the threshold for detecting local maxima")
CLI.add_argument("--maxima_threshold_window", nargs='?', type=none_or_float, default=None,
help="time window (s) to set the threshold for detecting local maxima")
CLI.add_argument("--min_peak_distance", nargs='?', type=float, default=0.200,
help="minimum distance between peaks (s)")
CLI.add_argument("--img_dir", nargs='?', type=Path,
Expand All @@ -38,10 +38,10 @@
help='example image filename for channel 0')
CLI.add_argument("--plot_channels", nargs='+', type=none_or_int, default=None,
help="list of channels to plot")
CLI.add_argument("--plot_tstart", nargs='?', type=none_or_float, default=0.,
help="start time in seconds")
CLI.add_argument("--plot_tstop", nargs='?', type=none_or_float, default=10.,
help="stop time in seconds")
CLI.add_argument("--plot_tstart", nargs='?', type=none_or_float, default=0,
help="start time (s)")
CLI.add_argument("--plot_tstop", nargs='?', type=none_or_float, default=10,
help="stop time (s)")

def filter_minima_order(signal, mins, order=1):
filtered_mins = np.array([], dtype=int)
Expand All @@ -56,36 +56,50 @@ def filter_minima_order(signal, mins, order=1):
return filtered_mins


def moving_threshold(signal, window, fraction):
# compute a dynamic threshold function through a sliding window
# on the signal array
strides = np.lib.stride_tricks.sliding_window_view(signal, window)
threshold_func = np.min(strides, axis=1) + fraction*np.ptp(strides, axis=1)
def moving_threshold(signal, maxima_threshold_window, maxima_threshold_fraction):
threshold_signal = []

sampling_rate = signal.sampling_rate.rescale('Hz').magnitude
duration = float(signal.t_stop.rescale('s').magnitude - signal.t_start.rescale('s').magnitude)
if maxima_threshold_window is None or maxima_threshold_window > duration:
maxima_threshold_window = duration
window_frame = int(maxima_threshold_window*sampling_rate)

for channel, channel_signal in enumerate(signal.T):
if np.isnan(channel_signal).any():
threshold_func = np.full(np.shape(signal)[0], np.nan)
else:
# compute a dynamic threshold function through a sliding window
# on the signal array
strides = np.lib.stride_tricks.sliding_window_view(channel_signal, window_frame)
threshold_func = np.min(strides, axis=1) + maxima_threshold_fraction*np.ptp(strides, axis=1)
# add elements at the beginning
threshold_func = np.append(np.ones(window_frame//2)*threshold_func[0], threshold_func)
threshold_func = np.append(threshold_func, np.ones(len(channel_signal)-len(threshold_func))*threshold_func[-1])
threshold_signal.append(threshold_func)

threshold_signal = neo.AnalogSignal(np.array(threshold_signal).T, units=signal.units, sampling_rate=signal.sampling_rate)
Comment on lines -59 to +81
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why create a separate threshold AnalogSignal here? It seems to require more code and more RAM.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the previous version of the code, the threshold was computed channel-wise in the detect_minima function, and then again in the plot_minima function for each channel to be plotted. The new implementation computes it only once per channel.
Also, for a given channel to be visualized, the signal to be plotted was trimmed to [plot_tstart, plot_tstop] before being passed to the plot_minima function. This implied that the moving threshold had to be computed on the trimmed signal, with minima found and plotted accordingly. This differed from what was done in the detect_minima function, where moving thresholds and minima were computed on the whole signal, with no trimming. As a result, the minima shown in the plots could differ from the minima actually detected for a given channel. Thanks to @mdenker for pointing this out during one of our discussions.
So, we decided to compute the moving threshold consistently on the whole signal, only once for each channel, and then use the resulting threshold_signal for both the subsequent minima detection and the visualization. On the one hand, this certainly implies larger RAM usage; on the other hand, it ensures a more coherent and robust computation of the moving thresholds. At the same time, the additional lines of code in the moving_threshold function are almost balanced by the now-redundant lines removed elsewhere in the code.


# add elements at the beginning
threshold_func = np.append(np.ones(window//2)*threshold_func[0], threshold_func)
threshold_func = np.append(threshold_func, np.ones(len(signal)-len(threshold_func))*threshold_func[-1])
return threshold_func
return threshold_signal


def detect_minima(asig, interpolation_points, maxima_threshold_fraction,
maxima_threshold_window, min_peak_distance, minima_persistence):
def detect_minima(asig, threshold_asig, interpolation_points,
min_peak_distance, minima_persistence):
signal = asig.as_array()
times = asig.times.rescale('s').magnitude
sampling_rate = asig.sampling_rate.rescale('Hz').magnitude
window_frame = int(maxima_threshold_window*sampling_rate)
threshold = threshold_asig.as_array()

min_time_idx = np.array([], dtype=int)
channel_idx = np.array([], dtype=int)

minima_order = int(np.max([minima_persistence*sampling_rate, 1]))
min_distance = np.max([min_peak_distance*sampling_rate, 1])

for channel, channel_signal in enumerate(signal.T):
if np.isnan(channel_signal).any(): continue

threshold_func = moving_threshold(channel_signal, window_frame, maxima_threshold_fraction)
peaks, _ = find_peaks(channel_signal, distance=min_distance, height=threshold_func)
peaks, _ = find_peaks(channel_signal, distance=min_distance, height=threshold.T[channel])
dmins, _ = find_peaks(-channel_signal)#, distance=min_distance)

mins = filter_minima_order(channel_signal, dmins, order=minima_order)
Expand All @@ -101,7 +115,7 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction,

min_time_idx = np.append(min_time_idx, clean_mins)
channel_idx = np.append(channel_idx, np.ones(len(clean_mins), dtype=int)*channel)

# compute local minima times.
if interpolation_points:
# parabolic fit on the right branch of local minima
Expand All @@ -126,12 +140,12 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction,
minimum_times = min_pos/sampling_rate*pq.s
else:
minimum_times = asig.times[min_time_idx]

idx = np.where(minimum_times >= np.max(asig.times))[0]
minimum_times[idx] = np.max(asig.times)
###################################
sort_idx = np.argsort(minimum_times)

# save detected minima as transition
evt = neo.Event(times=minimum_times[sort_idx],
labels=['UP'] * len(minimum_times),
Expand All @@ -146,40 +160,38 @@ def detect_minima(asig, interpolation_points, maxima_threshold_fraction,

remove_annotations(asig)
evt.annotations.update(asig.annotations)

return evt


def plot_minima(asig, event, channel, maxima_threshold_window,
maxima_threshold_fraction, min_peak_distance):
def plot_minima(asig, event, threshold_asig, channel, min_peak_distance):
signal = asig.as_array().T[channel]
times = asig.times.rescale('s')
sampling_rate = asig.sampling_rate.rescale('Hz').magnitude
window_frame = int(maxima_threshold_window*sampling_rate)
threshold_func = moving_threshold(signal, window_frame, maxima_threshold_fraction)
event = time_slice(event, asig.times[0], asig.times[-1])
threshold = threshold_asig.as_array().T[channel]

peaks, _ = find_peaks(signal, height=threshold_func,
peaks, _ = find_peaks(signal, height=threshold,
distance=np.max([min_peak_distance*sampling_rate, 1]))

# plot figure
sns.set(style='ticks', palette="deep", context="notebook")
fig, ax = plt.subplots()

ax.plot(times, signal, label='signal', color='k')
ax.plot(times, threshold_func, label='moving threshold',
ax.plot(times, threshold, label='moving threshold',
linestyle=':', color='b')

idx_ch = np.where(event.array_annotations['channels'] == channel)[0]

ax.plot(times[peaks], signal[peaks], 'x', color='r', label='detected maxima')
ax.plot(event.times[idx_ch], signal[(event.times[idx_ch]*sampling_rate).astype(int)],
ax.plot(event.times[idx_ch],
signal[((event.times[idx_ch]-asig.times[0])*sampling_rate).astype(int)],
'x', color='g', label='selected minima')

ax.set_title(f'channel {channel}')
ax.set_xlabel('time [s]')
ax.legend()

return ax


Expand All @@ -189,13 +201,13 @@ def plot_minima(asig, event, channel, maxima_threshold_window,
block = load_neo(args.data)
asig = block.segments[0].analogsignals[0]

transition_event = detect_minima(asig,
threshold_asig = moving_threshold(asig, args.maxima_threshold_window, args.maxima_threshold_fraction)

transition_event = detect_minima(asig, threshold_asig,
interpolation_points=args.num_interpolation_points,
maxima_threshold_fraction=args.maxima_threshold_fraction,
maxima_threshold_window=args.maxima_threshold_window,
min_peak_distance=args.min_peak_distance,
minima_persistence=args.minima_persistence)

block.segments[0].events.append(transition_event)

write_neo(args.output, block)
Expand All @@ -204,9 +216,8 @@ def plot_minima(asig, event, channel, maxima_threshold_window,
for channel in args.plot_channels:
plot_minima(asig=time_slice(asig, args.plot_tstart, args.plot_tstop),
event=time_slice(transition_event, args.plot_tstart, args.plot_tstop),
threshold_asig=time_slice(threshold_asig, args.plot_tstart, args.plot_tstop),
channel=int(channel),
maxima_threshold_window = args.maxima_threshold_window,
maxima_threshold_fraction = args.maxima_threshold_fraction,
min_peak_distance = args.min_peak_distance)
min_peak_distance=args.min_peak_distance)
output_path = args.img_dir / args.img_name.replace('_channel0', f'_channel{channel}')
save_plot(output_path)