From 9806dbf80725947db3b0b3cb21cd9d8818f0c615 Mon Sep 17 00:00:00 2001
From: jnsbck-uni
Date: Fri, 3 Nov 2023 13:25:59 +0100
Subject: [PATCH] wip: add careful spike processing option.

If enabled, spikes in the pre-, post- and stimulus intervals are processed
separately.
---
 ephyspy/__version__.py |  2 +-
 ephyspy/sweeps.py      | 67 ++++++++++++++++++++++++++++++++++++++----
 2 files changed, 63 insertions(+), 6 deletions(-)

diff --git a/ephyspy/__version__.py b/ephyspy/__version__.py
index c3707ef..080741b 100644
--- a/ephyspy/__version__.py
+++ b/ephyspy/__version__.py
@@ -14,6 +14,6 @@
 # You should have received a copy of the GNU General Public License
 # along with this program. If not, see <https://www.gnu.org/licenses/>.
 
-VERSION = (0, 0, 6)
+VERSION = (0, 0, 7)
 
 __version__ = ".".join(map(str, VERSION))
diff --git a/ephyspy/sweeps.py b/ephyspy/sweeps.py
index 64d420e..30a4a1c 100644
--- a/ephyspy/sweeps.py
+++ b/ephyspy/sweeps.py
@@ -20,6 +20,7 @@
 
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 from matplotlib.pyplot import Axes
 from numpy import ndarray
 from pandas import DataFrame
@@ -58,6 +59,7 @@ def __init__(
         start: Optional[float] = None,
         end: Optional[float] = None,
         metadata: Dict = {},
+        careful_spike_processing: bool = False,
         **kwargs,
     ):
         """
@@ -65,6 +67,9 @@
            metadata (dict, optional): Metadata for the sweep. Defaults to None.
                The metadata can be used to set hyperparameters for features or
                store identifying information, such as cell id etc..
+           careful_spike_processing (bool, optional): Whether to perform spike
+               processing carefully, i.e. detect pre-, post- and during-stimulus
+               spikes separately. Typically leads to fewer errors, but can be slower.
            *args: Additional arguments for EphysSweepFeatureExtractor.
            **kwargs: Additional keyword arguments for EphysSweepFeatureExtractor.
        """
@@ -72,6 +77,7 @@
         self.metadata = metadata
         self.added_spike_features = {}
         self.features = {}
+        self.careful_spike_processing = careful_spike_processing
         self._init_sweep()
 
     def _init_sweep(self):
@@ -86,8 +92,8 @@
         self.t = self.t[:idx_end]
         self.v = self.v[:idx_end]
         self.i = self.i[:idx_end]
-        self.start = self.t[0]
-        self.end = self.t[-1]
+        self.start = self.t[0] if self.start is None else self.start
+        self.end = self.t[-1] if self.end is None else self.end
 
     def add_spike_feature(self, feature_name: str, feature_func: Callable):
         """Add a new spike feature to the extractor.
@@ -130,9 +136,55 @@ def _process_added_spike_features(self):
     def process_spikes(self):
         """Perform spike-related feature analysis, which includes added spike
         features not part of the original AllenSDK implementation."""
-        self._process_individual_spikes()
-        self._process_spike_related_features()
-        self._process_added_spike_features()
+
+        def run_spike_processing():
+            self._process_individual_spikes()
+            self._process_spike_related_features()
+            self._process_added_spike_features()
+            if not self._spikes_df.empty:
+                self._spikes_df["T_start"] = self.start
+                self._spikes_df["T_end"] = self.end
+
+        where_stimulus = self.i != 0
+        if np.any(where_stimulus):
+            stim_onset, stim_end = self.t[where_stimulus][[0, -1]]
+            same_t = lambda t1, t2, tol=1e-3: (
+                abs(t1 - t2) < tol if t1 is not None and t2 is not None else False
+            )
+        else:
+            stim_onset, stim_end = None, None
+            same_t = lambda t1, t2, tol=1e-3: False
+
+        if (
+            same_t(stim_onset, self.start)
+            and same_t(stim_end, self.end)
+            or not self.careful_spike_processing
+        ):
+            run_spike_processing()
+        else:
+            t_intervals = [self.t[0], stim_onset, stim_end, self.t[-1]]
+            spike_dfs = []
+            orig_interval = (self.start, self.end)
+            for t_start, t_end in zip(t_intervals[:-1], t_intervals[1:]):
+                self.start = t_start
+                self.end = t_end
+                run_spike_processing()
+                spike_dfs.append(self._spikes_df)
+                del self._spikes_df
+
+            self.start, self.end = orig_interval
+            self._spikes_df = pd.concat(spike_dfs)
+
+            # remove duplicate spikes at interval boundaries
+            if not self._spikes_df.empty:
+                T_lockout = 1e-3
+                boundary_idxs = np.where(self._spikes_df["T_start"].diff() > 0)[0]
+                for idx in boundary_idxs:
+                    ap1_t, ap2_t = self._spikes_df.iloc[[idx - 1, idx]]["threshold_t"]
+                    if ap1_t + T_lockout > ap2_t:
+                        rm_idx = idx if stim_onset < ap1_t < stim_end else idx - 1
+                        self._spikes_df.drop(rm_idx, inplace=True)
+                self._spikes_df.reset_index(drop=True, inplace=True)
 
     def get_features(self, recompute: bool = False) -> Dict[str, float]:
         """Compute all features that have been added to the `EphysSweep` instance.
@@ -309,6 +361,7 @@ def __init__(
         start: Optional[Union[List, ndarray, float]] = None,
         end: Optional[Union[List, ndarray, float]] = None,
         metadata: Dict = {},
+        careful_spike_processing: bool = False,
         *args,
         **kwargs,
     ):
@@ -320,6 +373,9 @@
            t_start (ndarray, optional): Start time for each sweep.
            t_end (ndarray, optional): End time for each sweep.
            metadata (dict, optional): Metadata for the sweep set.
+           careful_spike_processing (bool, optional): Whether to perform spike
+               processing carefully, i.e. detect pre-, post- and during-stimulus
+               spikes separately. Typically leads to fewer errors, but can be slower.
            *args: Additional arguments for EphysSweepSetFeatureExtractor.
            **kwargs: Additional keyword arguments for EphysSweepSetFeatureExtractor.
        """
@@ -337,6 +393,7 @@
         self.metadata = metadata
         for sweep in self.sweeps():
             sweep.metadata = metadata
+            sweep.careful_spike_processing = careful_spike_processing
         self.features = {}
 
     @property
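
Usage sketch (not part of the patch): a minimal, hypothetical example of the new
option. Only careful_spike_processing and process_spikes() come from the diff
above; the positional (t, v, i) constructor arguments and the toy traces are
assumptions about the surrounding EphysSweep API and may differ in detail.

    import numpy as np
    from ephyspy.sweeps import EphysSweep

    # toy sweep: 1 s sampled at 10 kHz with a current step from 0.2 s to 0.8 s
    t = np.arange(0.0, 1.0, 1e-4)                     # time in s
    i = np.where((t >= 0.2) & (t < 0.8), 200.0, 0.0)  # stimulus in pA
    v = np.full_like(t, -70.0)                        # placeholder voltage in mV

    # With careful_spike_processing=True, process_spikes() analyses the pre-,
    # during- and post-stimulus intervals separately, concatenates the resulting
    # spike DataFrames, and drops duplicate spikes at the interval boundaries.
    sweep = EphysSweep(t, v, i, careful_spike_processing=True)
    sweep.process_spikes()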