Skip to content

Commit

Permalink
wip: add careful spike processing option. If this is the case, then s…
Browse files Browse the repository at this point in the history
…pikes in pre, post and stimulus interval are processed seperately
  • Loading branch information
jnsbck committed Nov 3, 2023
1 parent 5ddc119 commit 9806dbf
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 6 deletions.
2 changes: 1 addition & 1 deletion ephyspy/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,6 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.

VERSION = (0, 0, 6)
VERSION = (0, 0, 7)

__version__ = ".".join(map(str, VERSION))
67 changes: 62 additions & 5 deletions ephyspy/sweeps.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.pyplot import Axes
from numpy import ndarray
from pandas import DataFrame
Expand Down Expand Up @@ -58,20 +59,25 @@ def __init__(
start: Optional[float] = None,
end: Optional[float] = None,
metadata: Dict = {},
careful_spike_processing: bool = False,
**kwargs,
):
"""
Args:
metadata (dict, optional): Metadata for the sweep. Defaults to None.
The metadata can be used to set hyperparameters for features or
store identifying information, such as cell id etc..
careful_spike_processing (bool, optional): Whether to perform spike
processing carefully, i.e. detect pre-, post- and during-stimulus
spikes seperately. Typically leads to less errors, but can be slower.
*args: Additional arguments for EphysSweepFeatureExtractor.
**kwargs: Additional keyword arguments for EphysSweepFeatureExtractor.
"""
super().__init__(t=t, v=v, i=i, start=start, end=end, **kwargs)
self.metadata = metadata
self.added_spike_features = {}
self.features = {}
self.careful_spike_processing = careful_spike_processing
self._init_sweep()

def _init_sweep(self):
Expand All @@ -86,8 +92,8 @@ def _init_sweep(self):
self.t = self.t[:idx_end]
self.v = self.v[:idx_end]
self.i = self.i[:idx_end]
self.start = self.t[0]
self.end = self.t[-1]
self.start = self.t[0] if self.start is None else self.start
self.end = self.t[-1] if self.end is None else self.end

def add_spike_feature(self, feature_name: str, feature_func: Callable):
"""Add a new spike feature to the extractor.
Expand Down Expand Up @@ -130,9 +136,55 @@ def _process_added_spike_features(self):
def process_spikes(self):
"""Perform spike-related feature analysis, which includes added spike
features not part of the original AllenSDK implementation."""
self._process_individual_spikes()
self._process_spike_related_features()
self._process_added_spike_features()

def run_spike_processing():
self._process_individual_spikes()
self._process_spike_related_features()
self._process_added_spike_features()
if not self._spikes_df.empty:
self._spikes_df["T_start"] = self.start
self._spikes_df["T_end"] = self.end

where_stimulus = self.i != 0
if np.any(where_stimulus):
stim_onset, stim_end = self.t[where_stimulus][[0, -1]]
same_t = lambda t1, t2, tol=1e-3: (
abs(t1 - t2) < tol if t1 is not None and t2 is not None else False
)
else:
stim_onset, stim_end = None, None
same_t = lambda t1, t2, tol=1e-3: False

if (
same_t(stim_onset, self.start)
and same_t(stim_end, self.end)
or not self.careful_spike_processing
):
run_spike_processing()
else:
t_intervals = [self.t[0], stim_onset, stim_end, self.t[-1]]
spike_dfs = []
orig_interval = (self.start, self.end)
for t_start, t_end in zip(t_intervals[:-1], t_intervals[1:]):
self.start = t_start
self.end = t_end
run_spike_processing()
spike_dfs.append(self._spikes_df)
del self._spikes_df

self.start, self.end = orig_interval
self._spikes_df = pd.concat(spike_dfs)

# remove duplicate spikes at interval boundaries
if not self._spikes_df.empty:
T_lockout = 1e-3
boundary_idxs = np.where(self._spikes_df["T_start"].diff() > 0)[0]
for idx in boundary_idxs:
ap1_t, ap2_t = self._spikes_df.iloc[[idx - 1, idx]]["threshold_t"]
if ap1_t + T_lockout > ap2_t:
rm_idx = idx if stim_onset < ap1_t < stim_end else idx - 1
self._spikes_df.drop(rm_idx, inplace=True)
self._spikes_df.reset_index(drop=True, inplace=True)

def get_features(self, recompute: bool = False) -> Dict[str, float]:
"""Compute all features that have been added to the `EphysSweep` instance.
Expand Down Expand Up @@ -309,6 +361,7 @@ def __init__(
start: Optional[Union[List, ndarray, float]] = None,
end: Optional[Union[List, ndarray, float]] = None,
metadata: Dict = {},
careful_spike_processing: bool = False,
*args,
**kwargs,
):
Expand All @@ -320,6 +373,9 @@ def __init__(
t_start (ndarray, optional): Start time for each sweep.
t_end (ndarray, optional): End time for each sweep.
metadata (dict, optional): Metadata for the sweep set.
careful_spike_processing (bool, optional): Whether to perform spike
processing carefully, i.e. detect pre-, post- and during-stimulus
spikes seperately. Typically leads to less errors, but can be slower.
*args: Additional arguments for EphysSweepSetFeatureExtractor.
**kwargs: Additional keyword arguments for EphysSweepSetFeatureExtractor.
"""
Expand All @@ -337,6 +393,7 @@ def __init__(
self.metadata = metadata
for sweep in self.sweeps():
sweep.metadata = metadata
sweep.careful_spike_processing = careful_spike_processing
self.features = {}

@property
Expand Down

0 comments on commit 9806dbf

Please sign in to comment.