From 1cee34e231de909d51b87922af2a2b73ccfcbc9e Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Wed, 1 Jun 2022 10:53:31 +0200 Subject: [PATCH 01/16] Initial commit for the upcoming pull request, which will introduce an online version of the Unitary Event Analysis into 'elephant' --- elephant/online.py | 738 ++++++++++++++++++ .../test/concurrent_test_runner_for_uea.py | 18 + elephant/test/test_online.py | 549 +++++++++++++ 3 files changed, 1305 insertions(+) create mode 100644 elephant/online.py create mode 100644 elephant/test/concurrent_test_runner_for_uea.py create mode 100644 elephant/test/test_online.py diff --git a/elephant/online.py b/elephant/online.py new file mode 100644 index 000000000..097b05ed2 --- /dev/null +++ b/elephant/online.py @@ -0,0 +1,738 @@ +import warnings +from collections import defaultdict, deque + +import neo +import numpy as np +import quantities as pq +import scipy.special as sc + +import elephant.conversion as conv +from elephant.unitary_event_analysis import * +from elephant.unitary_event_analysis import _winpos, _bintime, _UE + + +class OnlineUnitaryEventAnalysis: + """ + Online version of the unitary event analysis (UEA). + + This class facilitates methods to perform the unitary event analysis in an + online manner, i.e. data generation and data analysis happen concurrently. + The attributes of this class are eiter partial results or descriptive + parameters of the UEA. + + Parameters + ---------- + bw_size : pq.Quantity + Size of the bin window, which is used to bin the spike trains. + trigger_events : pq.Quantity + Quantity array of time points around which the trials are defined. + The time interval of a trial is defined as follows: + [trigger_event - trigger_pre_size, trigger_event + trigger_post_size] + trigger_pre_size : pq.Quantity + Interval size before the trigger event. It is used with + 'trigger_post_size' to define the trial. + trigger_post_size : pq.Quantity + Interval size after the trigger event. It is used with + 'trigger_pre_size' to define the trial. + saw_size : pq.Quantity + Size of the sliding analysis window, which is used to perform the UEA + on the trial segments. It advances with a step size defined by + 'saw_step'. + saw_step : pq.Quantity + Size of the step which is used to advance the sliding analysis window to + its next position / next trial segment to analyze. + n_neurons : int + Number of neurons which are analyzed with the UEA. + pattern_hash : int or list of int or None + A list of interested patterns in hash values (see `hash_from_pattern` + and `inverse_hash_from_pattern` functions in + 'elephant.unitary_event_analysis'). If None, all neurons are + participated. + Default: None + time_unit : pq.Quantity + This time unit is used for all calculations which requires a time + quantity. (Default: [s]) + save_n_trials : (positive) int + The number of trials `n` which will be saved after their analysis with a + queue following the FIFO strategy (first in, first out), i.e. only + the last `n` analyzed trials will be stored. (default: None) + + Attributes + ---------- + data_available_in_mv : boolean + Reflects the status of spike trains in the memory window. It is True, + when spike trains are in the memory window which were not yet analyzed. + Otherwise, it is False. + waiting_for_new_trigger : boolean + Reflects the status of the updating-algorithm in 'update_uea()'. + It is `True`, when the algorithm is in the state of pre- / post-trial + analysis, i.e. it expects the arrival of the next trigger event which + will define the next trial. Otherwise, it is `False`, when the algorithm + is within the analysis of the current trial, i.e. it does not need + the next trigger event at this moment. + trigger_events_left_over : boolean + Reflects the status of the trial defining events in the 'trigger_events' + list. It is `True`, when there are events left, which were not analyzed + yet. Otherwise, it is `False`. + mw : list of lists + Contains for each neuron the spikes which are currently available in + the memory window. + * 0-axis --> Neurons + * 1-axis --> Spike times + tw_size : pq.Quantity + The size of the trial window. It is the sum of 'trigger_pre_size' and + 'trigger_post_size'. + tw : list of lists + Contains for each neuron the spikes which belong to the current trial + and are available in the memory window. + * 0-axis --> Neurons + * 1-axis --> Spike times + tw_counter : int + Counts how many trails are yet analyzed. + n_bins : int + Number of bins which are used for the binning of the spikes of a trial. + bw : np.array of booleans + A binned representation of the current trial window. `True` indicates + the presence of a spike in the bin and `False` indicates absence of + a spike. + * 0-axis --> Neurons + * 1-axis --> Index position of the bin + saw_pos_counter : int + Represents the current position of the sliding analysis window. + n_windows : int + Total number of positions of the sliding analysis window. + n_trials : int + Total number of trials to analyze. + n_hashes: int + Number of patterns (coded as hashes) to be analyzed. (see also + 'elephant.unitary_event_analysis.hash_from_pattern()') + method : string + The method with which to compute the unitary events: + * 'analytic_TrialByTrial': calculate the analytical expectancy + on each trial, then sum over all trials + * 'analytic_TrialAverage': calculate the expectancy by averaging + over trials (cf. Gruen et al. 2003); + * 'surrogate_TrialByTrial': calculate the distribution of expected + coincidences by spike time randomization in each trial and + sum over trials + 'analytic_TrialAverage' and 'surrogate_TrialByTrial' are not supported + yet. + Default: 'analytic_trialByTrial' + n_surrogates : int + Number of surrogates which would be used when 'surrogate_TrialByTrial' + is chosen as 'method'. Yet 'surrogate_TrialByTrial' is not supported. + input_parameters : dict + Dictionary of the input parameters which would be used for calling + the offline version of UEA for the same data to get the same results. + Js : np.ndarray + JointSurprise of different given patterns within each window. + * 0-axis --> different window + * 1-axis --> different pattern hash + n_exp : np.ndarray + The expected number of coincidences of each pattern within each window. + * 0-axis --> different window + * 1-axis --> different pattern hash + n_emp : np.ndarray + The empirical number of coincidences of each pattern within each window. + * 0-axis --> different window + * 1-axis --> different pattern hash + rate_avg : np.ndarray + The average firing rate of each neuron of each pattern within + each window. + * 0-axis --> different window + * 1-axis --> different pattern hash + * 2-axis --> different neuron + indices : defaultdict + Dictionary contains for each trial the indices of pattern + within each window. + + Methods + ------- + get_results() + Returns the result dictionary with the following class attribute names + as keys and the corresponding attribute values as the complementary + value for the key: (see also Attributes section for respective key + descriptions) + * 'Js' + * 'indices' + * 'n_emp' + * 'n_exp' + * 'rate_avg' + * 'input_parameters' + update_uea(spiketrains, events) + Updates the entries of the result dictionary by processing the + new arriving 'spiketrains' and trial defining trigger 'events'. + reset(bw_size, trigger_events, trigger_pre_size, trigger_post_size, + saw_size, saw_step, n_neurons, pattern_hash) + Resets all class attributes to their initial (default) value. It is + actually a re-initialization which allows parameter adjustments. + + Returns + ------- + see 'get_results()' in Methods section + + Notes + ----- + Common abbreviations which are used in both code and documentation: + bw = bin window + tw = trial window + saw = sliding analysis window + idw = incoming data window + mw = memory window + + """ + + def __init__(self, bw_size=0.005 * pq.s, trigger_events=None, + trigger_pre_size=0.5 * pq.s, trigger_post_size=0.5 * pq.s, + saw_size=0.1 * pq.s, saw_step=0.005 * pq.s, n_neurons=2, + pattern_hash=None, time_unit=1 * pq.s, save_n_trials=None): + """ + Constructor. Initializes all attributes of the new instance. + """ + # state controlling booleans for the updating algorithm + self.data_available_in_mv = None + self.waiting_for_new_trigger = True + self.trigger_events_left_over = True + + # save constructor parameters + if time_unit.units != (pq.s and pq.ms): + warnings.warn(message=f"Unusual time units like {time_unit} can " + f"cause numerical imprecise results. " + f"Use `ms` or `s` instead!", + category=UserWarning) + self.time_unit = time_unit + self.bw_size = bw_size.rescale(self.time_unit) + if trigger_events is None: + self.trigger_events = [] + else: + self.trigger_events = trigger_events.rescale( + self.time_unit).tolist() + self.trigger_pre_size = trigger_pre_size.rescale(self.time_unit) + self.trigger_post_size = trigger_post_size.rescale(self.time_unit) + self.saw_size = saw_size.rescale(self.time_unit) # multiple of bw_size + self.saw_step = saw_step.rescale(self.time_unit) # multiple of bw_size + self.n_neurons = n_neurons + if pattern_hash is None: + pattern = [1] * n_neurons + self.pattern_hash = hash_from_pattern(pattern) + if np.issubdtype(type(self.pattern_hash), np.integer): + self.pattern_hash = [int(self.pattern_hash)] + self.save_n_trials = save_n_trials + + # initialize helper variables for the memory window (mw) + self.mw = [[] for _ in range(self.n_neurons)] # array of all spiketimes + + # initialize helper variables for the trial window (tw) + self.tw_size = self.trigger_pre_size + self.trigger_post_size + self.tw = [[] for _ in range(self.n_neurons)] # pointer to slice of mw + self.tw_counter = 0 + + # initialize helper variables for the bin window (bw) + self.n_bins = None + self.bw = None # binned copy of tw + + # initialize helper variable for the sliding analysis window (saw) + self.saw_pos_counter = 0 + self.n_windows = int(np.round( + (self.tw_size - self.saw_size + self.saw_step) / self.saw_step)) + + # determine the number trials and the number of patterns (hashes) + self.n_trials = len(self.trigger_events) + self.n_hashes = len(self.pattern_hash) + # (optional) save last `n` analysed trials for visualization + if self.save_n_trials is not None: + self.all_trials = deque(maxlen=self.save_n_trials) + + # save input parameters as dict like the offline version of UEA it does + # to facilitate a later comparison of the used parameters + self.method = 'analytic_TrialByTrial' + self.n_surrogates = 100 + self.input_parameters = dict(pattern_hash=self.pattern_hash, + bin_size=self.bw_size.rescale(pq.ms), + win_size=self.saw_size.rescale(pq.ms), + win_step=self.saw_step.rescale(pq.ms), + method=self.method, + t_start=0 * time_unit, + t_stop=self.tw_size, + n_surrogates=self.n_surrogates) + + # initialize the intermediate result arrays for the joint surprise (js), + # number of expected coincidences (n_exp), number of empirically found + # coincidences (n_emp), rate average of the analyzed neurons (rate_avg), + # as well as the indices of the saw position where coincidences appear + self.Js, self.n_exp, self.n_emp = np.zeros( + (3, self.n_windows, self.n_hashes), dtype=np.float64) + self.rate_avg = np.zeros( + (self.n_windows, self.n_hashes, self.n_neurons), dtype=np.float64) + self.indices = defaultdict(list) + + def get_all_saved_trials(self): + """ + Return the last `n`-trials which were analyzed. + + `n` is the number of trials which were saved after their analysis + using a queue with the FIFO strategy (first in, first out). + + Returns + ------- + : list of list of neo.SpikeTrain + A nested list of trials, neurons and their neo.SpikeTrain objects, + respectively. + """ + return list(self.all_trials) + + def get_results(self): + """ + Return result dictionary. + + Prepares the dictionary entries by reshaping them into the correct + shape with the correct dtype. + + Returns + ------- + : dict + Dictionary with the following class attribute names + as keys and the corresponding attribute values as the complementary + value for the key: (see also Attributes section for respective key + descriptions) + * 'Js' + * 'indices' + * 'n_emp' + * 'n_exp' + * 'rate_avg' + * 'input_parameters' + + """ + for key in self.indices.keys(): + self.indices[key] = np.hstack(self.indices[key]).flatten() + self.n_exp /= (self.saw_size / self.bw_size) + p = self._pval(self.n_emp.astype(np.float64), + self.n_exp.astype(np.float64)).flatten() + self.Js = jointJ(p) + self.rate_avg = (self.rate_avg * (self.saw_size / self.bw_size)) / \ + (self.saw_size * self.n_trials) + return { + 'Js': self.Js.reshape( + (self.n_windows, self.n_hashes)).astype(np.float32), + 'indices': self.indices, + 'n_emp': self.n_emp.reshape( + (self.n_windows, self.n_hashes)).astype(np.float32), + 'n_exp': self.n_exp.reshape( + (self.n_windows, self.n_hashes)).astype(np.float32), + 'rate_avg': self.rate_avg.reshape( + (self.n_windows, self.n_hashes, self.n_neurons)).astype( + np.float32), + 'input_parameters': self.input_parameters} + + def _pval(self, n_emp, n_exp): + """ + Calculates the probability of detecting 'n_emp' or more coincidences + based on a distribution with sole parameter 'n_exp'. + + To calculate this probability, the upper incomplete gamma function is + used. + + Parameters + ---------- + n_emp : int + Number of empirically observed coincidences. + n_exp : float + Number of theoretically expected coincidences. + + Returns + ------- + p : float + Probability of finding 'n_emp' or more coincidences based on a + distribution with sole parameter 'n_exp' + + """ + p = 1. - sc.gammaincc(n_emp, n_exp) + return p + + def _save_idw_into_mw(self, idw): + """ + Save in-incoming data window (IDW) into memory window (MW). + + This function appends for each neuron all the spikes which are arriving + with 'idw' into the respective sub-list of 'mv'. + + Parameters + --------- + idw : list of pq.Quantity arrays + * 0-axis --> Neurons + * 1-axis --> Spike times + + """ + for i in range(self.n_neurons): + self.mw[i] += idw[i].tolist() + + def _move_mw(self, new_t_start): + """ + Move memory window. + + This method moves the memory window, i.e. it removes for each neuron + all the spikes that occurred before the time point 'new_t_start'. + Spikes which occurred after 'new_t_start' will be kept. + + Parameters + ---------- + new_t_start : pq.Quantity + New start point in time of the memory window. Spikes which occurred + after this time point will be kept, otherwise removed. + + """ + for i in range(self.n_neurons): + idx = np.where(new_t_start > self.mw[i])[0] + # print(f"idx = {idx}") + if not len(idx) == 0: # move mv + self.mw[i] = self.mw[i][idx[-1] + 1:] + else: # keep mv + self.data_available_in_mv = False + + def _define_tw(self, trigger_event): + """ + Define trial window (TW) based on a trigger event. + + This method defines the trial window around the 'trigger_event', i.e. + it sets the start and stop of the trial, so that it covers the + following interval: + [trigger_event - trigger_pre_size, trigger_event + trigger_post_size] + Then it collects for each neuron all spike times from the memory window + which are within this interval and puts them into the trial window. + + Parameters + ---------- + trigger_event : pq.Quantity + Time point around which the trial will be defined. + + """ + self.trial_start = trigger_event - self.trigger_pre_size + self.trial_stop = trigger_event + self.trigger_post_size + for i in range(self.n_neurons): + self.tw[i] = [t for t in self.mw[i] + if (self.trial_start <= t) & (t <= self.trial_stop)] + + def _check_tw_overlap(self, current_trigger_event, next_trigger_event): + """ + Check if successive trials do overlap each other. + + This method checks whether two successive trials are overlapping + each other. To do this it compares the stop time of the precedent + trial and the start time of the subsequent trial. An overlap is present + if start time of the subsequent trial is before the stop time + of the precedent trial. + + Parameters + ---------- + current_trigger_event : pq.Quantity + Time point around which the current / precedent trial was defined. + next_trigger_event : pq.Quantity + Time point around which the next / subsequent trial will be defined. + + Returns + ------- + : boolean + If an overlap exists, return `True`. Otherwise, `False`. + + """ + if current_trigger_event + self.trigger_post_size > \ + next_trigger_event - self.trigger_pre_size: + return True + else: + return False + + def _apply_bw_to_tw(self): + """ + Apply bin window (BW) to trial window (TW). + + Perform the binning and clipping procedure on the trial window, i.e. + if at least one spike is within a bin, it is occupied and + if no spike is within a bin, it is empty. + + """ + self.n_bins = int(((self.trial_stop - self.trial_start) / + self.bw_size).simplified.item()) + self.bw = np.zeros((1, self.n_neurons, self.n_bins), dtype=np.int32) + spiketrains = [neo.SpikeTrain(np.array(st) * self.time_unit, + t_start=self.trial_start, + t_stop=self.trial_stop) + for st in self.tw] + bs = conv.BinnedSpikeTrain(spiketrains, t_start=self.trial_start, + t_stop=self.trial_stop, + bin_size=self.bw_size) + self.bw = bs.to_bool_array() + + def _set_saw_positions(self, t_start, t_stop, win_size, win_step, bin_size): + """ + Set positions of the sliding analysis window (SAW). + + Determines the positions of the sliding analysis window with respect to + the used window size 'win_size' and the advancing step 'win_step'. Also + converts this time points into bin-units, i.e. into multiple of the + 'bin_size' which facilitates indexing in upcoming steps. + + Warns + ----- + UserWarning: + * if the ratio between the 'win_size' and 'bin_size' is not + an integer + * if the ratio between the 'win_step' and 'bin_size' is not + an integer + + """ + self.t_winpos = _winpos(t_start, t_stop, win_size, win_step, + position='left-edge') + while len(self.t_winpos) != self.n_windows: + if len(self.t_winpos) > self.n_windows: + self.t_winpos = _winpos(t_start, t_stop - win_step / 2, + win_size, + win_step, position='left-edge') + else: + self.t_winpos = _winpos(t_start, t_stop + win_step / 2, + win_size, + win_step, position='left-edge') + self.t_winpos_bintime = _bintime(self.t_winpos, bin_size) + self.winsize_bintime = _bintime(win_size, bin_size) + self.winstep_bintime = _bintime(win_step, bin_size) + if self.winsize_bintime * bin_size != win_size: + warnings.warn(f"The ratio between the win_size ({win_size}) and the" + f" bin_size ({bin_size}) is not an integer") + if self.winstep_bintime * bin_size != win_step: + warnings.warn(f"The ratio between the win_step ({win_step}) and the" + f" bin_size ({bin_size}) is not an integer") + + def _move_saw_over_tw(self, t_stop_idw): + """ + Move sliding analysis window (SAW) over trial window (TW). + + This method iterates over each sliding analysis window position and + applies at each position the unitary event analysis, i.e. within each + window it counts the empirically found coincidences and saves their + indices where they appeared, calculates the expected number of + coincidences and determines the firing rates of the neurons. + The respective results are then used to update the class attributes + 'n_emp', 'n_exp', 'rate_avg' and 'indices'. + + Notes + ----- + The 'Js' attribute is not continuously updated, because the + joint-surprise is determined just when the user calls 'get_results()'. + This is due to the dependency of the distribution from which 'Js' is + calculated on the attributes 'n_emp' and 'n_exp'. Updating / changing + 'n_emp' and 'n_exp' changes also this distribution, so that it not any + more possible to simply sum the joint-surprise values of different + trials at the same sliding analysis window position, because they were + based on different distributions. + + """ + # define saw positions + self._set_saw_positions( + t_start=self.trial_start, t_stop=self.trial_stop, + win_size=self.saw_size, win_step=self.saw_step, + bin_size=self.bw_size) + + # iterate over saw positions + for i in range(self.saw_pos_counter, self.n_windows): + p_realtime = self.t_winpos[i] + p_bintime = self.t_winpos_bintime[i] - self.t_winpos_bintime[0] + # check if saw filled with data + if p_realtime + self.saw_size <= t_stop_idw: # saw is filled + mat_win = np.zeros((1, self.n_neurons, self.winsize_bintime)) + n_bins_in_current_saw = self.bw[ + :, + p_bintime:p_bintime + self.winsize_bintime].shape[ + 1] + if n_bins_in_current_saw < self.winsize_bintime: + mat_win[0] += np.pad( + self.bw[:, p_bintime:p_bintime + self.winsize_bintime], + (0, self.winsize_bintime - n_bins_in_current_saw), + "minimum")[0:2] + else: + mat_win[0] += \ + self.bw[:, p_bintime:p_bintime + self.winsize_bintime] + Js_win, rate_avg, n_exp_win, n_emp_win, indices_lst = _UE( + mat_win, pattern_hash=self.pattern_hash, + method=self.method, n_surrogates=self.n_surrogates) + self.rate_avg[i] += rate_avg + self.n_exp[i] += (np.round( + n_exp_win * (self.saw_size / self.bw_size))).astype(int) + self.n_emp[i] += n_emp_win + self.indices_lst = indices_lst + if len(self.indices_lst[0]) > 0: + self.indices[f"trial{self.tw_counter}"].append( + self.indices_lst[0] + p_bintime) + else: # saw is empty / half-filled -> pause iteration + self.saw_pos_counter = i + self.data_available_in_mv = False + break + if i == self.n_windows - 1: # last SAW position finished + self.saw_pos_counter = 0 + # move MV after SAW is finished with analysis of one trial + self._move_mw(new_t_start=self.trigger_events[ + self.tw_counter] + self.tw_size) + # save analysed trial for visualization + if self.save_n_trials: + _trial_start = 0 * pq.s + _trial_stop = self.tw_size + _offset = self.trigger_events[self.tw_counter] - \ + self.trigger_pre_size + normalized_spike_times = [] + for n in range(self.n_neurons): + normalized_spike_times.append( + np.array(self.tw[n]) * self.time_unit - _offset) + self.all_trials.append( + [neo.SpikeTrain(normalized_spike_times[m], + t_start=_trial_start, + t_stop=_trial_stop, + units=self.time_unit) + for m in range(self.n_neurons)]) + # reset bw + self.bw = np.zeros_like(self.bw) + if self.tw_counter <= self.n_trials - 1: + self.tw_counter += 1 + else: + self.waiting_for_new_trigger = True + self.trigger_events_left_over = False + self.data_available_in_mv = False + print(f"tw_counter = {self.tw_counter}") # DEBUG-aid + + def update_uea(self, spiketrains, events=None): + """ + Update unitary event analysis (UEA) with new arriving spike data from + the incoming data window (IDW). + + This method orchestrates the updating process. It saves the incoming + 'spiketrains' into the memory window (MW) and adds also the new trigger + 'events' into the 'trigger_events' list. Then depending on the state in + which the algorithm is, it processes the new 'spiketrains' respectivly. + There are two major states with each two substates between the algorithm + is switching. + + Warns + ----- + UserWarning + * if an overlap between successive trials exists, spike data + of these trials will be analysed twice. The user should adjust + the trigger events and/or the trial window size to increase + the interval between successive trials to avoid an overlap. + + Notes + ----- + Short summary of the different algorithm major states / substates: + 1. pre/post trial analysis: algorithm is waiting for IDW with + new trigger event + 1.1. IDW contains new trigger event + 1.2. IDW does not contain new trigger event + 2. within trial analysis: algorithm is waiting for IDW with + spikes of current trial + 2.1. IDW contains new trigger event + 2.2. IDW does not contain new trigger event, it just has new spikes + of the current trial + + """ + # rescale spiketrains to time_unit + spiketrains = [st.rescale(self.time_unit) + if st.t_start.units == st.units == st.t_stop + else st.rescale(st.units).rescale(self.time_unit) + for st in spiketrains] + + if events is None: + events = np.array([]) + if len(events) > 0: + for event in events: + if event not in self.trigger_events: + self.trigger_events.append(event.rescale(self.time_unit)) + self.trigger_events.sort() + self.n_trials = len(self.trigger_events) + # save incoming spikes (IDW) into memory (MW) + self._save_idw_into_mw(spiketrains) + # extract relevant time information + idw_t_start = spiketrains[0].t_start + idw_t_stop = spiketrains[0].t_stop + + # analyse all trials which are available in the memory + self.data_available_in_mv = True + while self.data_available_in_mv: + if self.tw_counter == self.n_trials: + break + if self.n_trials == 0: + current_trigger_event = np.inf * self.time_unit + next_trigger_event = np.inf * self.time_unit + else: + current_trigger_event = self.trigger_events[self.tw_counter] + if self.tw_counter <= self.n_trials - 2: + next_trigger_event = self.trigger_events[ + self.tw_counter + 1] + else: + next_trigger_event = np.inf * self.time_unit + + # # case 1: pre/post trial analysis, + # i.e. waiting for IDW with new trigger event + if self.waiting_for_new_trigger: + # # subcase 1: IDW contains trigger event + if (idw_t_start <= current_trigger_event) & \ + (current_trigger_event <= idw_t_stop): + self.waiting_for_new_trigger = False + if self.trigger_events_left_over: + # define TW around trigger event + self._define_tw(trigger_event=current_trigger_event) + # apply BW to available data in TW + self._apply_bw_to_tw() + # move SAW over available data in TW + self._move_saw_over_tw(t_stop_idw=idw_t_stop) + else: + pass + # # subcase 2: IDW does not contain trigger event + else: + self._move_mw( + new_t_start=idw_t_stop - self.trigger_pre_size) + + # # Case 2: within trial analysis, + # i.e. waiting for new IDW with spikes of current trial + else: + # # Subcase 3: IDW contains new trigger event + if (idw_t_start <= next_trigger_event) & \ + (next_trigger_event <= idw_t_stop): + # check if an overlap between current / next trial exists + if self._check_tw_overlap( + current_trigger_event=current_trigger_event, + next_trigger_event=next_trigger_event): + warnings.warn( + f"Data in trial {self.tw_counter} will be analysed " + f"twice! Adjust the trigger events and/or " + f"the trial window size.", UserWarning) + else: # no overlap exists + pass + # # Subcase 4: IDW does not contain trigger event, + # i.e. just new spikes of the current trial + else: + pass + if self.trigger_events_left_over: + # define trial TW around trigger event + self._define_tw(trigger_event=current_trigger_event) + # apply BW to available data in TW + self._apply_bw_to_tw() + # move SAW over available data in TW + self._move_saw_over_tw(t_stop_idw=idw_t_stop) + else: + pass + + def reset(self, bw_size=0.005 * pq.s, trigger_events=None, + trigger_pre_size=0.5 * pq.s, trigger_post_size=0.5 * pq.s, + saw_size=0.1 * pq.s, saw_step=0.005 * pq.s, n_neurons=2, + pattern_hash=None, time_unit=1 * pq.s): + """ + Resets all class attributes to their initial value. + + This reset is actually a re-initialization which allows parameter + adjustments, so that one instance of 'OnlineUnitaryEventAnalysis' can + be flexibly adjusted to changing experimental circumstances. + + Parameters + ---------- + (same as for the constructor; see docstring of constructor for details) + + """ + self.__init__(bw_size, trigger_events, trigger_pre_size, + trigger_post_size, saw_size, saw_step, n_neurons, + pattern_hash, time_unit) \ No newline at end of file diff --git a/elephant/test/concurrent_test_runner_for_uea.py b/elephant/test/concurrent_test_runner_for_uea.py new file mode 100644 index 000000000..e01fb810b --- /dev/null +++ b/elephant/test/concurrent_test_runner_for_uea.py @@ -0,0 +1,18 @@ +import unittest +from test_online import TestOnlineUnitaryEventAnalysis +from concurrencytest import ConcurrentTestSuite, fork_for_tests + +if __name__ == "__main__": + + loader = unittest.TestLoader() + suite = unittest.TestSuite() + + suite.addTest(loader.loadTestsFromTestCase(TestOnlineUnitaryEventAnalysis)) + runner = unittest.TextTestRunner(verbosity=3) + + # runs test sequentially + # result = runner.run(suite) + + # runs tests across 4 processes + concurrent_suite = ConcurrentTestSuite(suite, fork_for_tests(4)) + runner.run(concurrent_suite) diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py new file mode 100644 index 000000000..dc419b5bb --- /dev/null +++ b/elephant/test/test_online.py @@ -0,0 +1,549 @@ +import random +import unittest +from collections import defaultdict + +import matplotlib.pyplot as plt +import neo +import numpy as np +import quantities as pq +import viziphant + +from elephant.datasets import download_datasets +from elephant.online import OnlineUnitaryEventAnalysis +from elephant.spike_train_generation import homogeneous_poisson_process +from elephant.unitary_event_analysis import jointJ_window_analysis + + +def _generate_spiketrains(freq, length, trigger_events, injection_pos, + trigger_pre_size, trigger_post_size, + time_unit=1*pq.s): + """ + Generate two spiketrains from a homogeneous Poisson process with + injected coincidences. + """ + st1 = homogeneous_poisson_process(rate=freq, + t_start=(0*pq.s).rescale(time_unit), + t_stop=length.rescale(time_unit)) + st2 = homogeneous_poisson_process(rate=freq, + t_start=(0*pq.s.rescale(time_unit)), + t_stop=length.rescale(time_unit)) + # inject 10 coincidences within a 0.1s interval for each trial + injection = (np.linspace(0, 0.1, 10)*pq.s).rescale(time_unit) + all_injections = np.array([]) + for i in trigger_events: + all_injections = np.concatenate( + (all_injections, (i+injection_pos)+injection), axis=0) * time_unit + st1 = st1.duplicate_with_new_data( + np.sort(np.concatenate((st1.times, all_injections)))*time_unit) + st2 = st2.duplicate_with_new_data( + np.sort(np.concatenate((st2.times, all_injections)))*time_unit) + + # stack spiketrains by trial + st1_stacked = [st1.time_slice( + t_start=i - trigger_pre_size, + t_stop=i + trigger_post_size).time_shift(-i + trigger_pre_size) + for i in trigger_events] + st2_stacked = [st2.time_slice( + t_start=i - trigger_pre_size, + t_stop=i + trigger_post_size).time_shift(-i + trigger_pre_size) + for i in trigger_events] + spiketrains = np.stack((st1_stacked, st2_stacked), axis=1) + spiketrains = spiketrains.tolist() + + return spiketrains, st1, st2 + + +def _visualize_results_of_offline_and_online_uea( + spiketrains, online_trials, ue_dict_offline, ue_dict_online, alpha): + # rescale input-params 'bin_size', win_size' and 'win_step' to ms, + # because plot_ue() expects these parameters in ms + ue_dict_offline["input_parameters"]["bin_size"].units = pq.ms + ue_dict_offline["input_parameters"]["win_size"].units = pq.ms + ue_dict_offline["input_parameters"]["win_step"].units = pq.ms + viziphant.unitary_event_analysis.plot_ue( + spiketrains, Js_dict=ue_dict_offline, significance_level=alpha, + unit_real_ids=['1', '2'], suptitle="offline") + plt.show() + # reorder and rename indices-dict of ue_dict_online, if only the last + # n-trials were saved; indices-entries of unused trials are overwritten + if len(online_trials) < len(spiketrains): + _diff_n_trials = len(spiketrains) - len(online_trials) + for i in range(len(online_trials)): + ue_dict_online["indices"][f"trial{i}"] = \ + ue_dict_online["indices"].pop(f"trial{i+_diff_n_trials}") + viziphant.unitary_event_analysis.plot_ue( + online_trials, Js_dict=ue_dict_online, significance_level=alpha, + unit_real_ids=['1', '2'], suptitle="online") + plt.show() + + +def _simulate_buffered_reading(n_buffers, ouea, st1, st2, IDW_length, + length_remainder, events=None): + if events is None: + events = np.array([]) + for i in range(n_buffers): + buff_t_start = i * IDW_length + + if length_remainder > 1e-7 and i == n_buffers - 1: + buff_t_stop = i * IDW_length + length_remainder + else: + buff_t_stop = i * IDW_length + IDW_length + + events_in_buffer = np.array([]) + if len(events) > 0: + idx_events_in_buffer = (events >= buff_t_start) & \ + (events <= buff_t_stop) + events_in_buffer = events[idx_events_in_buffer].tolist() + events = events[np.logical_not(idx_events_in_buffer)] + + ouea.update_uea( + spiketrains=[ + st1.time_slice(t_start=buff_t_start, t_stop=buff_t_stop), + st2.time_slice(t_start=buff_t_start, t_stop=buff_t_stop)], + events=events_in_buffer) + print(f"#buffer = {i}") # DEBUG-aid + # # aid to create timelapses + # result_dict = ouea.get_results() + # viziphant.unitary_event_analysis.plot_ue( + # spiketrains[:i+1], Js_dict=result_dict, significance_level=0.05, + # unit_real_ids=['1', '2']) + # plt.savefig(f"plots/timelapse_UE/ue_real_data_buff_{i}.pdf") + + +def _load_real_data(n_trials, trial_length, time_unit): + # download data + repo_path = 'tutorials/tutorial_unitary_event_analysis/data/dataset-1.nix' + filepath = download_datasets(repo_path) + # load data and extract spiketrains + io = neo.io.NixIO(f"{filepath}", 'ro') + block = io.read_block() + spiketrains = [] + # each segment contains a single trial + for ind in range(len(block.segments)): + spiketrains.append(block.segments[ind].spiketrains) + # for each neuron: concatenate all trials to one long neo.Spiketrain + st1_long = [spiketrains[i].multiplexed[1][ + np.where(spiketrains[i].multiplexed[0] == 0)] + + i * trial_length + for i in range(len(spiketrains))] + st2_long = [spiketrains[i].multiplexed[1][ + np.where(spiketrains[i].multiplexed[0] == 1)] + + i * trial_length + for i in range(len(spiketrains))] + st1_concat = st1_long[0] + st2_concat = st2_long[0] + for i in range(1, len(st1_long)): + st1_concat = np.concatenate((st1_concat, st1_long[i])) + st2_concat = np.concatenate((st2_concat, st2_long[i])) + neo_st1 = neo.SpikeTrain((st1_concat / 1000) * pq.s, t_start=0 * pq.s, + t_stop=n_trials * trial_length).rescale(time_unit) + neo_st2 = neo.SpikeTrain((st2_concat / 1000) * pq.s, t_start=0 * pq.s, + t_stop=n_trials * trial_length).rescale(time_unit) + spiketrains = [[st[j].rescale(time_unit) for j in range(len(st))] for st in + spiketrains] + return spiketrains, neo_st1, neo_st2 + + +def _calculate_n_buffers(n_trials, tw_length, noise_length, idw_length): + _n_buffers_float = n_trials * (tw_length + noise_length) / idw_length + _n_buffers_int = int(_n_buffers_float) + _n_buffers_fraction = _n_buffers_float - _n_buffers_int + n_buffers = _n_buffers_int + 1 if _n_buffers_fraction > 1e-7 else \ + _n_buffers_int + length_remainder = idw_length * _n_buffers_fraction + return n_buffers, length_remainder + + +class TestOnlineUnitaryEventAnalysis(unittest.TestCase): + @classmethod + def setUpClass(cls): + np.random.seed(73) + cls.time_unit = 1 * pq.ms + cls.last_n_trials = 50 + + def setUp(self): + pass + + def _assert_equality_of_passed_and_saved_trials( + self, last_n_trials, passed_trials, saved_trials): + eps_float64 = np.finfo(np.float64).eps + n_neurons = len(passed_trials[0]) + with self.subTest("test 'trial' equality"): + for t in range(last_n_trials): + for n in range(n_neurons): + np.testing.assert_allclose( + actual=saved_trials[-t][n].rescale( + self.time_unit).magnitude, + desired=saved_trials[-t][n].rescale( + self.time_unit).magnitude, + atol=eps_float64, rtol=eps_float64) + + def _assert_equality_of_result_dicts(self, ue_dict_offline, ue_dict_online, + tol_dict_user): + eps_float64 = np.finfo(np.float64).eps + eps_float32 = np.finfo(np.float32).eps + tol_dict = {"atol_Js": eps_float64, "rtol_Js": eps_float64, + "atol_indices": eps_float64, "rtol_indices": eps_float64, + "atol_n_emp": eps_float64, "rtol_n_emp": eps_float64, + "atol_n_exp": eps_float64, "rtol_n_exp": eps_float64, + "atol_rate_avg": eps_float32, "rtol_rate_avg": eps_float32} + tol_dict.update(tol_dict_user) + + with self.subTest("test 'Js' equality"): + np.testing.assert_allclose( + actual=ue_dict_online["Js"], desired=ue_dict_offline["Js"], + atol=tol_dict["atol_Js"], + rtol=tol_dict["rtol_Js"]) + with self.subTest("test 'indices' equality"): + for key in ue_dict_offline["indices"].keys(): + np.testing.assert_allclose( + actual=ue_dict_online["indices"][key], + desired=ue_dict_offline["indices"][key], + atol=tol_dict["atol_indices"], + rtol=tol_dict["rtol_indices"]) + with self.subTest("test 'n_emp' equality"): + np.testing.assert_allclose( + actual=ue_dict_online["n_emp"], + desired=ue_dict_offline["n_emp"], + atol=tol_dict["atol_n_emp"], rtol=tol_dict["rtol_n_emp"]) + with self.subTest("test 'n_exp' equality"): + np.testing.assert_allclose( + actual=ue_dict_online["n_exp"], + desired=ue_dict_offline["n_exp"], + atol=tol_dict["atol_n_exp"], + rtol=tol_dict["rtol_n_exp"]) + with self.subTest("test 'rate_avg' equality"): + np.testing.assert_allclose( + actual=ue_dict_online["rate_avg"].magnitude, + desired=ue_dict_offline["rate_avg"].magnitude, + atol=tol_dict["atol_rate_avg"], rtol=tol_dict["rtol_rate_avg"]) + with self.subTest("test 'input_parameters' equality"): + for key in ue_dict_offline["input_parameters"].keys(): + np.testing.assert_equal( + actual=ue_dict_online["input_parameters"][key], + desired=ue_dict_offline["input_parameters"][key]) + + def _test_unitary_events_analysis_with_real_data( + self, idw_length, method="pass_events_at_initialization", + time_unit=1 * pq.s): + # Fix random seed to guarantee fixed output + random.seed(1224) + + # set relevant variables of this TestCase + n_trials = 36 # determined by real data + TW_length = (2.1 * pq.s).rescale(time_unit) # determined by real data + IDW_length = idw_length.rescale(time_unit) + noise_length = (0. * pq.s).rescale(time_unit) + trigger_events = (np.arange(0., n_trials * 2.1, 2.1) * pq.s).rescale( + time_unit) + n_buffers, length_remainder = _calculate_n_buffers( + n_trials=n_trials, tw_length=TW_length, + noise_length=noise_length, idw_length=IDW_length) + + # load data and extract spiketrains + # 36 trials with 2.1s length and 0s background noise in between trials + spiketrains, neo_st1, neo_st2 = _load_real_data(n_trials=n_trials, + trial_length=TW_length, + time_unit=time_unit) + + # perform standard unitary events analysis + ue_dict = jointJ_window_analysis( + spiketrains, bin_size=(0.005 * pq.s).rescale(time_unit), + winsize=(0.1 * pq.s).rescale(time_unit), + winstep=(0.005 * pq.s).rescale(time_unit), pattern_hash=[3]) + + if method == "pass_events_at_initialization": + init_events = trigger_events + reading_events = np.array([]) * time_unit + elif method == "pass_events_while_buffered_reading": + init_events = np.array([]) * time_unit + reading_events = trigger_events + else: + raise ValueError("Illegal method to pass events!") + + # create instance of OnlineUnitaryEventAnalysis + _last_n_trials = min(self.last_n_trials, len(spiketrains)) + ouea = OnlineUnitaryEventAnalysis( + bw_size=(0.005 * pq.s).rescale(time_unit), + trigger_pre_size=(0. * pq.s).rescale(time_unit), + trigger_post_size=(2.1 * pq.s).rescale(time_unit), + saw_size=(0.1 * pq.s).rescale(time_unit), + saw_step=(0.005 * pq.s).rescale(time_unit), + trigger_events=init_events, + time_unit=time_unit, + save_n_trials=_last_n_trials) + # perform online unitary event analysis + # simulate buffered reading/transport of spiketrains, + # i.e. loop over spiketrain list and call update_ue() + _simulate_buffered_reading(n_buffers=n_buffers, ouea=ouea, st1=neo_st1, + st2=neo_st2, IDW_length=IDW_length, + length_remainder=length_remainder, + events=reading_events) + ue_dict_online = ouea.get_results() + + # assert equality between result dicts of standard and online ue version + self._assert_equality_of_result_dicts( + ue_dict_offline=ue_dict, ue_dict_online=ue_dict_online, + tol_dict_user={}) + + self._assert_equality_of_passed_and_saved_trials( + last_n_trials=_last_n_trials, passed_trials=spiketrains, + saved_trials=ouea.get_all_saved_trials()) + + # visualize results of online and standard UEA for real data + # _visualize_results_of_offline_and_online_uea( + # spiketrains=spiketrains, + # online_trials=ouea.get_all_saved_trials(), + # ue_dict_offline=ue_dict, + # ue_dict_online=ue_dict_online, alpha=0.05) + + return ouea + + def _test_unitary_events_analysis_with_artificial_data( + self, idw_length, method="pass_events_at_initialization", + time_unit=1 * pq.s): + # Fix random seed to guarantee fixed output + random.seed(1224) + + # set relevant variables of this TestCase + n_trials = 40 + TW_length = (1 * pq.s).rescale(time_unit) + noise_length = (1.5 * pq.s).rescale(time_unit) + IDW_length = idw_length.rescale(time_unit) + trigger_events = (np.arange(0., n_trials*2.5, 2.5) * pq.s).rescale( + time_unit) + trigger_pre_size = (0. * pq.s).rescale(time_unit) + trigger_post_size = (1. * pq.s).rescale(time_unit) + n_buffers, length_remainder = _calculate_n_buffers( + n_trials=n_trials, tw_length=TW_length, + noise_length=noise_length, idw_length=IDW_length) + + # create two long random homogeneous poisson spiketrains which represent + # 40 trials with 1s length and 1.5s background noise in between trials + spiketrains, st1_long, st2_long = _generate_spiketrains( + freq=5*pq.Hz, length=(TW_length+noise_length)*n_trials, + trigger_events=trigger_events, + injection_pos=(0.6 * pq.s).rescale(time_unit), + trigger_pre_size=trigger_pre_size, + trigger_post_size=trigger_post_size, + time_unit=time_unit) + + # perform standard unitary event analysis + ue_dict = jointJ_window_analysis( + spiketrains, bin_size=(0.005 * pq.s).rescale(time_unit), + win_size=(0.1 * pq.s).rescale(time_unit), + win_step=(0.005 * pq.s).rescale(time_unit), pattern_hash=[3]) + + if method == "pass_events_at_initialization": + init_events = trigger_events + reading_events = np.array([]) * time_unit + elif method == "pass_events_while_buffered_reading": + init_events = np.array([]) * time_unit + reading_events = trigger_events + else: + raise ValueError("Illegal method to pass events!") + + # create instance of OnlineUnitaryEventAnalysis + _last_n_trials = min(self.last_n_trials, len(spiketrains)) + ouea = OnlineUnitaryEventAnalysis( + bw_size=(0.005 * pq.s).rescale(time_unit), + trigger_pre_size=trigger_pre_size, + trigger_post_size=trigger_post_size, + saw_size=(0.1 * pq.s).rescale(time_unit), + saw_step=(0.005 * pq.s).rescale(time_unit), + trigger_events=init_events, + time_unit=time_unit, + save_n_trials=_last_n_trials) + # perform online unitary event analysis + # simulate buffered reading/transport of spiketrains, + # i.e. loop over spiketrain list and call update_ue() + _simulate_buffered_reading(n_buffers=n_buffers, ouea=ouea, st1=st1_long, + st2=st2_long, IDW_length=IDW_length, + length_remainder=length_remainder, + events=reading_events) + ue_dict_online = ouea.get_results() + + # assert equality between result dicts of standard and online ue version + self._assert_equality_of_result_dicts( + ue_dict_offline=ue_dict, ue_dict_online=ue_dict_online, + tol_dict_user={}) + + self._assert_equality_of_passed_and_saved_trials( + last_n_trials=_last_n_trials, passed_trials=spiketrains, + saved_trials=ouea.get_all_saved_trials()) + + # visualize results of online and standard UEA for artificial data + # _visualize_results_of_offline_and_online_uea( + # spiketrains=spiketrains, + # online_trials=ouea.get_all_saved_trials(), + # ue_dict_offline=ue_dict, + # ue_dict_online=ue_dict_online, alpha=0.01) + + return ouea + + # test: trial window > in-coming data window (TW > IDW) + def test_TW_larger_IDW_artificial_data(self): + """Test, if online UE analysis is correct when the trial window is + larger than the in-coming data window with artificial data.""" + idw_length = ([0.995, 0.8, 0.6, 0.3, 0.1, 0.05]*pq.s).rescale( + self.time_unit) + for idw in idw_length: + with self.subTest(f"IDW = {idw}"): + self._test_unitary_events_analysis_with_artificial_data( + idw_length=idw, time_unit=self.time_unit) + self.doCleanups() + + def test_TW_larger_IDW_real_data(self): + """Test, if online UE analysis is correct when the trial window is + larger than the in-coming data window with real data.""" + idw_length = ([2.05, 2., 1.1, 0.8, 0.1, 0.05]*pq.s).rescale( + self.time_unit) + for idw in idw_length: + with self.subTest(f"IDW = {idw}"): + self._test_unitary_events_analysis_with_real_data( + idw_length=idw, time_unit=self.time_unit) + self.doCleanups() + + # test: trial window = in-coming data window (TW = IDW) + def test_TW_as_large_as_IDW_real_data(self): + """Test, if online UE analysis is correct when the trial window is + as large as the in-coming data window with real data.""" + idw_length = (2.1*pq.s).rescale(self.time_unit) + with self.subTest(f"IDW = {idw_length}"): + self._test_unitary_events_analysis_with_real_data( + idw_length=idw_length, time_unit=self.time_unit) + self.doCleanups() + + def test_TW_as_large_as_IDW_artificial_data(self): + """Test, if online UE analysis is correct when the trial window is + as large as the in-coming data window with artificial data.""" + idw_length = (1*pq.s).rescale(self.time_unit) + with self.subTest(f"IDW = {idw_length}"): + self._test_unitary_events_analysis_with_artificial_data( + idw_length=idw_length, time_unit=self.time_unit) + self.doCleanups() + + # test: trial window < in-coming data window (TW < IDW) + def test_TW_smaller_IDW_artificial_data(self): + """Test, if online UE analysis is correct when the trial window is + smaller than the in-coming data window with artificial data.""" + idw_length = ([1.05, 1.1, 2, 10, 50, 100]*pq.s).rescale(self.time_unit) + for idw in idw_length: + with self.subTest(f"IDW = {idw}"): + self._test_unitary_events_analysis_with_artificial_data( + idw_length=idw, time_unit=self.time_unit) + self.doCleanups() + + def test_TW_smaller_IDW_real_data(self): + """Test, if online UE analysis is correct when the trial window is + smaller than the in-coming data window with real data.""" + idw_length = ([2.15, 2.2, 3, 10, 50, 75.6]*pq.s).rescale(self.time_unit) + for idw in idw_length: + with self.subTest(f"IDW = {idw}"): + self._test_unitary_events_analysis_with_real_data( + idw_length=idw, time_unit=self.time_unit) + self.doCleanups() + + def test_pass_trigger_events_while_buffered_reading_real_data(self): + idw_length = (2.1*pq.s).rescale(self.time_unit) + with self.subTest(f"IDW = {idw_length}"): + self._test_unitary_events_analysis_with_real_data( + idw_length=idw_length, + method="pass_events_while_buffered_reading", + time_unit=self.time_unit) + self.doCleanups() + + def test_pass_trigger_events_while_buffered_reading_artificial_data(self): + idw_length = (1*pq.s).rescale(self.time_unit) + with self.subTest(f"IDW = {idw_length}"): + self._test_unitary_events_analysis_with_artificial_data( + idw_length=idw_length, + method="pass_events_while_buffered_reading", + time_unit=self.time_unit) + self.doCleanups() + + def test_reset(self): + idw_length = (2.1*pq.s).rescale(self.time_unit) + with self.subTest(f"IDW = {idw_length}"): + ouea = self._test_unitary_events_analysis_with_real_data( + idw_length=idw_length, time_unit=self.time_unit) + self.doCleanups() + # do reset with default parameters + ouea.reset() + # check all class attributes + with self.subTest(f"check 'time_unit'"): + self.assertEqual(ouea.time_unit, 1*pq.s) + with self.subTest(f"check 'data_available_in_mv'"): + self.assertEqual(ouea.data_available_in_mv, None) + with self.subTest(f"check 'waiting_for_new_trigger'"): + self.assertEqual(ouea.waiting_for_new_trigger, True) + with self.subTest(f"check 'trigger_events_left_over'"): + self.assertEqual(ouea.trigger_events_left_over, True) + with self.subTest(f"check 'bw_size'"): + self.assertEqual(ouea.bw_size, 0.005 * pq.s) + with self.subTest(f"check 'trigger_events'"): + self.assertEqual(ouea.trigger_events, []) + with self.subTest(f"check 'trigger_pre_size'"): + self.assertEqual(ouea.trigger_pre_size, 0.5 * pq.s) + with self.subTest(f"check 'trigger_post_size'"): + self.assertEqual(ouea.trigger_post_size, 0.5 * pq.s) + with self.subTest(f"check 'saw_size'"): + self.assertEqual(ouea.saw_size, 0.1 * pq.s) + with self.subTest(f"check 'saw_step'"): + self.assertEqual(ouea.saw_step, 0.005 * pq.s) + with self.subTest(f"check 'n_neurons'"): + self.assertEqual(ouea.n_neurons, 2) + with self.subTest(f"check 'pattern_hash'"): + self.assertEqual(ouea.pattern_hash, [3]) + with self.subTest(f"check 'mw'"): + np.testing.assert_equal(ouea.mw, [[] for _ in range(2)]) + with self.subTest(f"check 'tw_size'"): + self.assertEqual(ouea.tw_size, 1 * pq.s) + with self.subTest(f"check 'tw'"): + np.testing.assert_equal(ouea.tw, [[] for _ in range(2)]) + with self.subTest(f"check 'tw_counter'"): + self.assertEqual(ouea.tw_counter, 0) + with self.subTest(f"check 'n_bins'"): + self.assertEqual(ouea.n_bins, None) + with self.subTest(f"check 'bw'"): + self.assertEqual(ouea.bw, None) + with self.subTest(f"check 'saw_pos_counter'"): + self.assertEqual(ouea.saw_pos_counter, 0) + with self.subTest(f"check 'n_windows'"): + self.assertEqual(ouea.n_windows, 181) + with self.subTest(f"check 'n_trials'"): + self.assertEqual(ouea.n_trials, 0) + with self.subTest(f"check 'n_hashes'"): + self.assertEqual(ouea.n_hashes, 1) + with self.subTest(f"check 'method'"): + self.assertEqual(ouea.method, 'analytic_TrialByTrial') + with self.subTest(f"check 'n_surrogates'"): + self.assertEqual(ouea.n_surrogates, 100) + with self.subTest(f"check 'input_parameters'"): + self.assertEqual(ouea.input_parameters["pattern_hash"], [3]) + self.assertEqual(ouea.input_parameters["bin_size"], 5 * pq.ms) + self.assertEqual(ouea.input_parameters["win_size"], 100 * pq.ms) + self.assertEqual(ouea.input_parameters["win_step"], 5 * pq.ms) + self.assertEqual(ouea.input_parameters["method"], + 'analytic_TrialByTrial') + self.assertEqual(ouea.input_parameters["t_start"], 0 * pq.s) + self.assertEqual(ouea.input_parameters["t_stop"], 1 * pq.s) + self.assertEqual(ouea.input_parameters["n_surrogates"], 100) + with self.subTest(f"check 'Js'"): + np.testing.assert_equal(ouea.Js, np.zeros((181, 1), + dtype=np.float64)) + with self.subTest(f"check 'n_exp'"): + np.testing.assert_equal(ouea.n_exp, np.zeros((181, 1), + dtype=np.float64)) + with self.subTest(f"check 'n_emp'"): + np.testing.assert_equal(ouea.n_emp, np.zeros((181, 1), + dtype=np.float64)) + with self.subTest(f"check 'rate_avg'"): + np.testing.assert_equal(ouea.rate_avg, np.zeros((181, 1, 2), + dtype=np.float64)) + with self.subTest(f"check 'indices'"): + np.testing.assert_equal(ouea.indices, defaultdict(list)) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 28aedfe6b9343e0bb6e12d25913d147b0d939b7c Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Wed, 1 Jun 2022 10:57:49 +0200 Subject: [PATCH 02/16] added new line to the end of the files --- elephant/online.py | 2 +- elephant/test/test_online.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/elephant/online.py b/elephant/online.py index 097b05ed2..b21cf7a1d 100644 --- a/elephant/online.py +++ b/elephant/online.py @@ -735,4 +735,4 @@ def reset(self, bw_size=0.005 * pq.s, trigger_events=None, """ self.__init__(bw_size, trigger_events, trigger_pre_size, trigger_post_size, saw_size, saw_step, n_neurons, - pattern_hash, time_unit) \ No newline at end of file + pattern_hash, time_unit) diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py index dc419b5bb..2f2d9142a 100644 --- a/elephant/test/test_online.py +++ b/elephant/test/test_online.py @@ -546,4 +546,4 @@ def test_reset(self): if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From 2b5fb87da05851b37338cea3fc33a25ba8d6ec65 Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Wed, 1 Jun 2022 16:39:13 +0200 Subject: [PATCH 03/16] download real spike-date once in the 'setUpClass()' of the 'OnlineUnitaryEventAnalysis' TestCase and load it then several times for the respective test-functions which use this dataset; uncommented import/usage of matplotlib within this TestCase (the created plots are just a possible debug help for the developer, but not needed for executing the test-functions) --- elephant/test/test_online.py | 22 ++++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py index 2f2d9142a..bd49707dc 100644 --- a/elephant/test/test_online.py +++ b/elephant/test/test_online.py @@ -2,7 +2,7 @@ import unittest from collections import defaultdict -import matplotlib.pyplot as plt +# import matplotlib.pyplot as plt import neo import numpy as np import quantities as pq @@ -63,7 +63,7 @@ def _visualize_results_of_offline_and_online_uea( viziphant.unitary_event_analysis.plot_ue( spiketrains, Js_dict=ue_dict_offline, significance_level=alpha, unit_real_ids=['1', '2'], suptitle="offline") - plt.show() + # plt.show() # reorder and rename indices-dict of ue_dict_online, if only the last # n-trials were saved; indices-entries of unused trials are overwritten if len(online_trials) < len(spiketrains): @@ -74,7 +74,7 @@ def _visualize_results_of_offline_and_online_uea( viziphant.unitary_event_analysis.plot_ue( online_trials, Js_dict=ue_dict_online, significance_level=alpha, unit_real_ids=['1', '2'], suptitle="online") - plt.show() + # plt.show() def _simulate_buffered_reading(n_buffers, ouea, st1, st2, IDW_length, @@ -110,10 +110,7 @@ def _simulate_buffered_reading(n_buffers, ouea, st1, st2, IDW_length, # plt.savefig(f"plots/timelapse_UE/ue_real_data_buff_{i}.pdf") -def _load_real_data(n_trials, trial_length, time_unit): - # download data - repo_path = 'tutorials/tutorial_unitary_event_analysis/data/dataset-1.nix' - filepath = download_datasets(repo_path) +def _load_real_data(filepath, n_trials, trial_length, time_unit): # load data and extract spiketrains io = neo.io.NixIO(f"{filepath}", 'ro') block = io.read_block() @@ -161,6 +158,11 @@ def setUpClass(cls): cls.time_unit = 1 * pq.ms cls.last_n_trials = 50 + # download real data once and load it several times later + cls.repo_path = 'tutorials/tutorial_unitary_event_analysis/' \ + 'data/dataset-1.nix' + cls.filepath = download_datasets(cls.repo_path) + def setUp(self): pass @@ -242,9 +244,9 @@ def _test_unitary_events_analysis_with_real_data( # load data and extract spiketrains # 36 trials with 2.1s length and 0s background noise in between trials - spiketrains, neo_st1, neo_st2 = _load_real_data(n_trials=n_trials, - trial_length=TW_length, - time_unit=time_unit) + spiketrains, neo_st1, neo_st2 = _load_real_data( + filepath=self.filepath, n_trials=n_trials, trial_length=TW_length, + time_unit=time_unit) # perform standard unitary events analysis ue_dict = jointJ_window_analysis( From a5c1b51f1ae20f74800b8e8efb9d56cc6608d979 Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Wed, 1 Jun 2022 16:55:26 +0200 Subject: [PATCH 04/16] fixed too long lines to meet PEP8 rules --- elephant/online.py | 62 +++++++++++++++++++----------------- elephant/test/test_online.py | 13 ++++---- 2 files changed, 39 insertions(+), 36 deletions(-) diff --git a/elephant/online.py b/elephant/online.py index b21cf7a1d..240a7fb7f 100644 --- a/elephant/online.py +++ b/elephant/online.py @@ -39,8 +39,8 @@ class OnlineUnitaryEventAnalysis: on the trial segments. It advances with a step size defined by 'saw_step'. saw_step : pq.Quantity - Size of the step which is used to advance the sliding analysis window to - its next position / next trial segment to analyze. + Size of the step which is used to advance the sliding analysis window + to its next position / next trial segment to analyze. n_neurons : int Number of neurons which are analyzed with the UEA. pattern_hash : int or list of int or None @@ -53,8 +53,8 @@ class OnlineUnitaryEventAnalysis: This time unit is used for all calculations which requires a time quantity. (Default: [s]) save_n_trials : (positive) int - The number of trials `n` which will be saved after their analysis with a - queue following the FIFO strategy (first in, first out), i.e. only + The number of trials `n` which will be saved after their analysis with + a queue following the FIFO strategy (first in, first out), i.e. only the last `n` analyzed trials will be stored. (default: None) Attributes @@ -67,13 +67,13 @@ class OnlineUnitaryEventAnalysis: Reflects the status of the updating-algorithm in 'update_uea()'. It is `True`, when the algorithm is in the state of pre- / post-trial analysis, i.e. it expects the arrival of the next trigger event which - will define the next trial. Otherwise, it is `False`, when the algorithm - is within the analysis of the current trial, i.e. it does not need - the next trigger event at this moment. + will define the next trial. Otherwise, it is `False`, when the + algorithm is within the analysis of the current trial, i.e. it does + not need the next trigger event at this moment. trigger_events_left_over : boolean - Reflects the status of the trial defining events in the 'trigger_events' - list. It is `True`, when there are events left, which were not analyzed - yet. Otherwise, it is `False`. + Reflects the status of the trial defining events in the + 'trigger_events' list. It is `True`, when there are events left, which + were not analyzed yet. Otherwise, it is `False`. mw : list of lists Contains for each neuron the spikes which are currently available in the memory window. @@ -133,7 +133,8 @@ class OnlineUnitaryEventAnalysis: * 0-axis --> different window * 1-axis --> different pattern hash n_emp : np.ndarray - The empirical number of coincidences of each pattern within each window. + The empirical number of coincidences of each pattern within each + window. * 0-axis --> different window * 1-axis --> different pattern hash rate_avg : np.ndarray @@ -220,7 +221,7 @@ def __init__(self, bw_size=0.005 * pq.s, trigger_events=None, self.save_n_trials = save_n_trials # initialize helper variables for the memory window (mw) - self.mw = [[] for _ in range(self.n_neurons)] # array of all spiketimes + self.mw = [[] for _ in range(self.n_neurons)] # list of all spiketimes # initialize helper variables for the trial window (tw) self.tw_size = self.trigger_pre_size + self.trigger_post_size @@ -256,10 +257,11 @@ def __init__(self, bw_size=0.005 * pq.s, trigger_events=None, t_stop=self.tw_size, n_surrogates=self.n_surrogates) - # initialize the intermediate result arrays for the joint surprise (js), - # number of expected coincidences (n_exp), number of empirically found - # coincidences (n_emp), rate average of the analyzed neurons (rate_avg), - # as well as the indices of the saw position where coincidences appear + # initialize the intermediate result arrays for the joint surprise + # (js), number of expected coincidences (n_exp), number of empirically + # found coincidences (n_emp), rate average of the analyzed neurons + # (rate_avg), as well as the indices of the saw position where + # coincidences appear self.Js, self.n_exp, self.n_emp = np.zeros( (3, self.n_windows, self.n_hashes), dtype=np.float64) self.rate_avg = np.zeros( @@ -427,7 +429,8 @@ def _check_tw_overlap(self, current_trigger_event, next_trigger_event): current_trigger_event : pq.Quantity Time point around which the current / precedent trial was defined. next_trigger_event : pq.Quantity - Time point around which the next / subsequent trial will be defined. + Time point around which the next / subsequent trial will be + defined. Returns ------- @@ -462,7 +465,8 @@ def _apply_bw_to_tw(self): bin_size=self.bw_size) self.bw = bs.to_bool_array() - def _set_saw_positions(self, t_start, t_stop, win_size, win_step, bin_size): + def _set_saw_positions(self, t_start, t_stop, win_size, win_step, + bin_size): """ Set positions of the sliding analysis window (SAW). @@ -495,11 +499,11 @@ def _set_saw_positions(self, t_start, t_stop, win_size, win_step, bin_size): self.winsize_bintime = _bintime(win_size, bin_size) self.winstep_bintime = _bintime(win_step, bin_size) if self.winsize_bintime * bin_size != win_size: - warnings.warn(f"The ratio between the win_size ({win_size}) and the" - f" bin_size ({bin_size}) is not an integer") + warnings.warn(f"The ratio between the win_size ({win_size}) and " + f"the bin_size ({bin_size}) is not an integer") if self.winstep_bintime * bin_size != win_step: - warnings.warn(f"The ratio between the win_step ({win_step}) and the" - f" bin_size ({bin_size}) is not an integer") + warnings.warn(f"The ratio between the win_step ({win_step}) and " + f"the bin_size ({bin_size}) is not an integer") def _move_saw_over_tw(self, t_stop_idw): """ @@ -539,9 +543,7 @@ def _move_saw_over_tw(self, t_stop_idw): if p_realtime + self.saw_size <= t_stop_idw: # saw is filled mat_win = np.zeros((1, self.n_neurons, self.winsize_bintime)) n_bins_in_current_saw = self.bw[ - :, - p_bintime:p_bintime + self.winsize_bintime].shape[ - 1] + :, p_bintime:p_bintime + self.winsize_bintime].shape[1] if n_bins_in_current_saw < self.winsize_bintime: mat_win[0] += np.pad( self.bw[:, p_bintime:p_bintime + self.winsize_bintime], @@ -575,7 +577,7 @@ def _move_saw_over_tw(self, t_stop_idw): _trial_start = 0 * pq.s _trial_stop = self.tw_size _offset = self.trigger_events[self.tw_counter] - \ - self.trigger_pre_size + self.trigger_pre_size normalized_spike_times = [] for n in range(self.n_neurons): normalized_spike_times.append( @@ -605,8 +607,8 @@ def update_uea(self, spiketrains, events=None): 'spiketrains' into the memory window (MW) and adds also the new trigger 'events' into the 'trigger_events' list. Then depending on the state in which the algorithm is, it processes the new 'spiketrains' respectivly. - There are two major states with each two substates between the algorithm - is switching. + There are two major states with each two substates between the + algorithm is switching. Warns ----- @@ -698,8 +700,8 @@ def update_uea(self, spiketrains, events=None): current_trigger_event=current_trigger_event, next_trigger_event=next_trigger_event): warnings.warn( - f"Data in trial {self.tw_counter} will be analysed " - f"twice! Adjust the trigger events and/or " + f"Data in trial {self.tw_counter} will be analysed" + f" twice! Adjust the trigger events and/or " f"the trial window size.", UserWarning) else: # no overlap exists pass diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py index bd49707dc..afefd8d91 100644 --- a/elephant/test/test_online.py +++ b/elephant/test/test_online.py @@ -283,7 +283,7 @@ def _test_unitary_events_analysis_with_real_data( events=reading_events) ue_dict_online = ouea.get_results() - # assert equality between result dicts of standard and online ue version + # assert equality between result dicts of standard / online ue version self._assert_equality_of_result_dicts( ue_dict_offline=ue_dict, ue_dict_online=ue_dict_online, tol_dict_user={}) @@ -320,7 +320,7 @@ def _test_unitary_events_analysis_with_artificial_data( n_trials=n_trials, tw_length=TW_length, noise_length=noise_length, idw_length=IDW_length) - # create two long random homogeneous poisson spiketrains which represent + # create two long random homogeneous poisson spiketrains representing # 40 trials with 1s length and 1.5s background noise in between trials spiketrains, st1_long, st2_long = _generate_spiketrains( freq=5*pq.Hz, length=(TW_length+noise_length)*n_trials, @@ -359,13 +359,13 @@ def _test_unitary_events_analysis_with_artificial_data( # perform online unitary event analysis # simulate buffered reading/transport of spiketrains, # i.e. loop over spiketrain list and call update_ue() - _simulate_buffered_reading(n_buffers=n_buffers, ouea=ouea, st1=st1_long, - st2=st2_long, IDW_length=IDW_length, + _simulate_buffered_reading(n_buffers=n_buffers, ouea=ouea, st1=st1_long + , st2=st2_long, IDW_length=IDW_length, length_remainder=length_remainder, events=reading_events) ue_dict_online = ouea.get_results() - # assert equality between result dicts of standard and online ue version + # assert equality between result dicts of standard / online ue version self._assert_equality_of_result_dicts( ue_dict_offline=ue_dict, ue_dict_online=ue_dict_online, tol_dict_user={}) @@ -439,7 +439,8 @@ def test_TW_smaller_IDW_artificial_data(self): def test_TW_smaller_IDW_real_data(self): """Test, if online UE analysis is correct when the trial window is smaller than the in-coming data window with real data.""" - idw_length = ([2.15, 2.2, 3, 10, 50, 75.6]*pq.s).rescale(self.time_unit) + idw_length = ([2.15, 2.2, 3, 10, 50, 75.6]*pq.s).rescale( + self.time_unit) for idw in idw_length: with self.subTest(f"IDW = {idw}"): self._test_unitary_events_analysis_with_real_data( From 9027049ccfec1268fee5c04ec0dd3e8b0ce44d3b Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Fri, 10 Jun 2022 09:42:31 +0200 Subject: [PATCH 05/16] added missing Parameters descriptions in the docstrings for several methods --- elephant/online.py | 38 ++++++++++++++++++++++++++++++++------ 1 file changed, 32 insertions(+), 6 deletions(-) diff --git a/elephant/online.py b/elephant/online.py index 240a7fb7f..a947f9629 100644 --- a/elephant/online.py +++ b/elephant/online.py @@ -473,7 +473,21 @@ def _set_saw_positions(self, t_start, t_stop, win_size, win_step, Determines the positions of the sliding analysis window with respect to the used window size 'win_size' and the advancing step 'win_step'. Also converts this time points into bin-units, i.e. into multiple of the - 'bin_size' which facilitates indexing in upcoming steps. + 'bin_size' which facilitates indexing in upcoming analysis steps. + + Parameters + ---------- + t_start : pq.Quantity + Time point at which the current trial starts. + t_stop : pq.Quantity + Time point at which the current trial ends. + win_size : pq.Quantity + Temporal length of the sliding analysis window. + win_step : pq.Quantity + Temporal size of the advancing step of the sliding analysis window. + bin_size : pq.Quantity + Temporal length of the histogram bins, which were used to bin + the 'spiketrains' in '_apply_bw_tw()'. Warns ----- @@ -517,6 +531,11 @@ def _move_saw_over_tw(self, t_stop_idw): The respective results are then used to update the class attributes 'n_emp', 'n_exp', 'rate_avg' and 'indices'. + Parameters + ---------- + t_stop_idw : pq.Quantity + Time point at which the current incoming data window (IDW) ends. + Notes ----- The 'Js' attribute is not continuously updated, because the @@ -604,11 +623,18 @@ def update_uea(self, spiketrains, events=None): the incoming data window (IDW). This method orchestrates the updating process. It saves the incoming - 'spiketrains' into the memory window (MW) and adds also the new trigger - 'events' into the 'trigger_events' list. Then depending on the state in - which the algorithm is, it processes the new 'spiketrains' respectivly. - There are two major states with each two substates between the - algorithm is switching. + 'spiketrains' into the memory window (MW) and adds also the new + trigger 'events' into the 'trigger_events' list. Then depending on + the state in which the algorithm is, it processes the new + 'spiketrains' respectively. There are two major states with each two + substates between the algorithm is switching. + + Parameters + ---------- + spiketrains : list of neo.SpikeTrain objects + Spike times of the analysed neurons. + events : list of pq.Quantity + Time points of the trial defining trigger events. Warns ----- From 1a50fc8273b503ca8f8454ebecbeb99f6bacbb57 Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Tue, 21 Jun 2022 11:03:16 +0200 Subject: [PATCH 06/16] added missing method-parameters of 'reset()' into the docstring and to the call of '__init__()'; changed order of class methods: public methods are now first in alphabetical order, followed by the private helper methods which are unordered; changed order of testing the class attributes after performing the 'reset()' method in the test-function 'test_reset' to be the same as in the docstring of the 'OnlineUnitaryEventAnalysis' class --- elephant/online.py | 306 ++++++++++++++++++----------------- elephant/test/test_online.py | 18 ++- 2 files changed, 165 insertions(+), 159 deletions(-) diff --git a/elephant/online.py b/elephant/online.py index a947f9629..3c64e1f7c 100644 --- a/elephant/online.py +++ b/elephant/online.py @@ -60,9 +60,9 @@ class OnlineUnitaryEventAnalysis: Attributes ---------- data_available_in_mv : boolean - Reflects the status of spike trains in the memory window. It is True, + Reflects the status of spike trains in the memory window. It is `True', when spike trains are in the memory window which were not yet analyzed. - Otherwise, it is False. + Otherwise, it is `False'. waiting_for_new_trigger : boolean Reflects the status of the updating-algorithm in 'update_uea()'. It is `True`, when the algorithm is in the state of pre- / post-trial @@ -164,9 +164,13 @@ class OnlineUnitaryEventAnalysis: Updates the entries of the result dictionary by processing the new arriving 'spiketrains' and trial defining trigger 'events'. reset(bw_size, trigger_events, trigger_pre_size, trigger_post_size, - saw_size, saw_step, n_neurons, pattern_hash) + saw_size, saw_step, n_neurons, pattern_hash, time_unit, + save_n_trials) Resets all class attributes to their initial (default) value. It is actually a re-initialization which allows parameter adjustments. + get_all_saved_trials() + Returns the last 'n'-trials which were analyzed according to the FIFO + strategy (first in, first out). Returns ------- @@ -326,6 +330,154 @@ def get_results(self): np.float32), 'input_parameters': self.input_parameters} + def reset(self, bw_size=0.005 * pq.s, trigger_events=None, + trigger_pre_size=0.5 * pq.s, trigger_post_size=0.5 * pq.s, + saw_size=0.1 * pq.s, saw_step=0.005 * pq.s, n_neurons=2, + pattern_hash=None, time_unit=1 * pq.s, save_n_trials=None): + """ + Resets all class attributes to their initial value. + + This reset is actually a re-initialization which allows parameter + adjustments, so that one instance of 'OnlineUnitaryEventAnalysis' can + be flexibly adjusted to changing experimental circumstances. + + Parameters + ---------- + (same as for the constructor; see docstring of constructor for details) + + """ + self.__init__(bw_size, trigger_events, trigger_pre_size, + trigger_post_size, saw_size, saw_step, n_neurons, + pattern_hash, time_unit, save_n_trials) + + def update_uea(self, spiketrains, events=None): + """ + Update unitary event analysis (UEA) with new arriving spike data from + the incoming data window (IDW). + + This method orchestrates the updating process. It saves the incoming + 'spiketrains' into the memory window (MW) and adds also the new + trigger 'events' into the 'trigger_events' list. Then depending on + the state in which the algorithm is, it processes the new + 'spiketrains' respectively. There are two major states with each two + substates between the algorithm is switching. + + Parameters + ---------- + spiketrains : list of neo.SpikeTrain objects + Spike times of the analysed neurons. + events : list of pq.Quantity + Time points of the trial defining trigger events. + + Warns + ----- + UserWarning + * if an overlap between successive trials exists, spike data + of these trials will be analysed twice. The user should adjust + the trigger events and/or the trial window size to increase + the interval between successive trials to avoid an overlap. + + Notes + ----- + Short summary of the different algorithm major states / substates: + 1. pre/post trial analysis: algorithm is waiting for IDW with + new trigger event + 1.1. IDW contains new trigger event + 1.2. IDW does not contain new trigger event + 2. within trial analysis: algorithm is waiting for IDW with + spikes of current trial + 2.1. IDW contains new trigger event + 2.2. IDW does not contain new trigger event, it just has new spikes + of the current trial + + """ + # rescale spiketrains to time_unit + spiketrains = [st.rescale(self.time_unit) + if st.t_start.units == st.units == st.t_stop + else st.rescale(st.units).rescale(self.time_unit) + for st in spiketrains] + + if events is None: + events = np.array([]) + if len(events) > 0: + for event in events: + if event not in self.trigger_events: + self.trigger_events.append(event.rescale(self.time_unit)) + self.trigger_events.sort() + self.n_trials = len(self.trigger_events) + # save incoming spikes (IDW) into memory (MW) + self._save_idw_into_mw(spiketrains) + # extract relevant time information + idw_t_start = spiketrains[0].t_start + idw_t_stop = spiketrains[0].t_stop + + # analyse all trials which are available in the memory + self.data_available_in_mv = True + while self.data_available_in_mv: + if self.tw_counter == self.n_trials: + break + if self.n_trials == 0: + current_trigger_event = np.inf * self.time_unit + next_trigger_event = np.inf * self.time_unit + else: + current_trigger_event = self.trigger_events[self.tw_counter] + if self.tw_counter <= self.n_trials - 2: + next_trigger_event = self.trigger_events[ + self.tw_counter + 1] + else: + next_trigger_event = np.inf * self.time_unit + + # # case 1: pre/post trial analysis, + # i.e. waiting for IDW with new trigger event + if self.waiting_for_new_trigger: + # # subcase 1: IDW contains trigger event + if (idw_t_start <= current_trigger_event) & \ + (current_trigger_event <= idw_t_stop): + self.waiting_for_new_trigger = False + if self.trigger_events_left_over: + # define TW around trigger event + self._define_tw(trigger_event=current_trigger_event) + # apply BW to available data in TW + self._apply_bw_to_tw() + # move SAW over available data in TW + self._move_saw_over_tw(t_stop_idw=idw_t_stop) + else: + pass + # # subcase 2: IDW does not contain trigger event + else: + self._move_mw( + new_t_start=idw_t_stop - self.trigger_pre_size) + + # # Case 2: within trial analysis, + # i.e. waiting for new IDW with spikes of current trial + else: + # # Subcase 3: IDW contains new trigger event + if (idw_t_start <= next_trigger_event) & \ + (next_trigger_event <= idw_t_stop): + # check if an overlap between current / next trial exists + if self._check_tw_overlap( + current_trigger_event=current_trigger_event, + next_trigger_event=next_trigger_event): + warnings.warn( + f"Data in trial {self.tw_counter} will be analysed" + f" twice! Adjust the trigger events and/or " + f"the trial window size.", UserWarning) + else: # no overlap exists + pass + # # Subcase 4: IDW does not contain trigger event, + # i.e. just new spikes of the current trial + else: + pass + if self.trigger_events_left_over: + # define trial TW around trigger event + self._define_tw(trigger_event=current_trigger_event) + # apply BW to available data in TW + self._apply_bw_to_tw() + # move SAW over available data in TW + self._move_saw_over_tw(t_stop_idw=idw_t_stop) + else: + pass + def _pval(self, n_emp, n_exp): """ Calculates the probability of detecting 'n_emp' or more coincidences @@ -616,151 +768,3 @@ def _move_saw_over_tw(self, t_stop_idw): self.trigger_events_left_over = False self.data_available_in_mv = False print(f"tw_counter = {self.tw_counter}") # DEBUG-aid - - def update_uea(self, spiketrains, events=None): - """ - Update unitary event analysis (UEA) with new arriving spike data from - the incoming data window (IDW). - - This method orchestrates the updating process. It saves the incoming - 'spiketrains' into the memory window (MW) and adds also the new - trigger 'events' into the 'trigger_events' list. Then depending on - the state in which the algorithm is, it processes the new - 'spiketrains' respectively. There are two major states with each two - substates between the algorithm is switching. - - Parameters - ---------- - spiketrains : list of neo.SpikeTrain objects - Spike times of the analysed neurons. - events : list of pq.Quantity - Time points of the trial defining trigger events. - - Warns - ----- - UserWarning - * if an overlap between successive trials exists, spike data - of these trials will be analysed twice. The user should adjust - the trigger events and/or the trial window size to increase - the interval between successive trials to avoid an overlap. - - Notes - ----- - Short summary of the different algorithm major states / substates: - 1. pre/post trial analysis: algorithm is waiting for IDW with - new trigger event - 1.1. IDW contains new trigger event - 1.2. IDW does not contain new trigger event - 2. within trial analysis: algorithm is waiting for IDW with - spikes of current trial - 2.1. IDW contains new trigger event - 2.2. IDW does not contain new trigger event, it just has new spikes - of the current trial - - """ - # rescale spiketrains to time_unit - spiketrains = [st.rescale(self.time_unit) - if st.t_start.units == st.units == st.t_stop - else st.rescale(st.units).rescale(self.time_unit) - for st in spiketrains] - - if events is None: - events = np.array([]) - if len(events) > 0: - for event in events: - if event not in self.trigger_events: - self.trigger_events.append(event.rescale(self.time_unit)) - self.trigger_events.sort() - self.n_trials = len(self.trigger_events) - # save incoming spikes (IDW) into memory (MW) - self._save_idw_into_mw(spiketrains) - # extract relevant time information - idw_t_start = spiketrains[0].t_start - idw_t_stop = spiketrains[0].t_stop - - # analyse all trials which are available in the memory - self.data_available_in_mv = True - while self.data_available_in_mv: - if self.tw_counter == self.n_trials: - break - if self.n_trials == 0: - current_trigger_event = np.inf * self.time_unit - next_trigger_event = np.inf * self.time_unit - else: - current_trigger_event = self.trigger_events[self.tw_counter] - if self.tw_counter <= self.n_trials - 2: - next_trigger_event = self.trigger_events[ - self.tw_counter + 1] - else: - next_trigger_event = np.inf * self.time_unit - - # # case 1: pre/post trial analysis, - # i.e. waiting for IDW with new trigger event - if self.waiting_for_new_trigger: - # # subcase 1: IDW contains trigger event - if (idw_t_start <= current_trigger_event) & \ - (current_trigger_event <= idw_t_stop): - self.waiting_for_new_trigger = False - if self.trigger_events_left_over: - # define TW around trigger event - self._define_tw(trigger_event=current_trigger_event) - # apply BW to available data in TW - self._apply_bw_to_tw() - # move SAW over available data in TW - self._move_saw_over_tw(t_stop_idw=idw_t_stop) - else: - pass - # # subcase 2: IDW does not contain trigger event - else: - self._move_mw( - new_t_start=idw_t_stop - self.trigger_pre_size) - - # # Case 2: within trial analysis, - # i.e. waiting for new IDW with spikes of current trial - else: - # # Subcase 3: IDW contains new trigger event - if (idw_t_start <= next_trigger_event) & \ - (next_trigger_event <= idw_t_stop): - # check if an overlap between current / next trial exists - if self._check_tw_overlap( - current_trigger_event=current_trigger_event, - next_trigger_event=next_trigger_event): - warnings.warn( - f"Data in trial {self.tw_counter} will be analysed" - f" twice! Adjust the trigger events and/or " - f"the trial window size.", UserWarning) - else: # no overlap exists - pass - # # Subcase 4: IDW does not contain trigger event, - # i.e. just new spikes of the current trial - else: - pass - if self.trigger_events_left_over: - # define trial TW around trigger event - self._define_tw(trigger_event=current_trigger_event) - # apply BW to available data in TW - self._apply_bw_to_tw() - # move SAW over available data in TW - self._move_saw_over_tw(t_stop_idw=idw_t_stop) - else: - pass - - def reset(self, bw_size=0.005 * pq.s, trigger_events=None, - trigger_pre_size=0.5 * pq.s, trigger_post_size=0.5 * pq.s, - saw_size=0.1 * pq.s, saw_step=0.005 * pq.s, n_neurons=2, - pattern_hash=None, time_unit=1 * pq.s): - """ - Resets all class attributes to their initial value. - - This reset is actually a re-initialization which allows parameter - adjustments, so that one instance of 'OnlineUnitaryEventAnalysis' can - be flexibly adjusted to changing experimental circumstances. - - Parameters - ---------- - (same as for the constructor; see docstring of constructor for details) - - """ - self.__init__(bw_size, trigger_events, trigger_pre_size, - trigger_post_size, saw_size, saw_step, n_neurons, - pattern_hash, time_unit) diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py index afefd8d91..11117f5a1 100644 --- a/elephant/test/test_online.py +++ b/elephant/test/test_online.py @@ -474,14 +474,6 @@ def test_reset(self): # do reset with default parameters ouea.reset() # check all class attributes - with self.subTest(f"check 'time_unit'"): - self.assertEqual(ouea.time_unit, 1*pq.s) - with self.subTest(f"check 'data_available_in_mv'"): - self.assertEqual(ouea.data_available_in_mv, None) - with self.subTest(f"check 'waiting_for_new_trigger'"): - self.assertEqual(ouea.waiting_for_new_trigger, True) - with self.subTest(f"check 'trigger_events_left_over'"): - self.assertEqual(ouea.trigger_events_left_over, True) with self.subTest(f"check 'bw_size'"): self.assertEqual(ouea.bw_size, 0.005 * pq.s) with self.subTest(f"check 'trigger_events'"): @@ -498,6 +490,16 @@ def test_reset(self): self.assertEqual(ouea.n_neurons, 2) with self.subTest(f"check 'pattern_hash'"): self.assertEqual(ouea.pattern_hash, [3]) + with self.subTest(f"check 'time_unit'"): + self.assertEqual(ouea.time_unit, 1*pq.s) + with self.subTest(f"check 'save_n_trials'"): + self.assertEqual(ouea.save_n_trials, None) + with self.subTest(f"check 'data_available_in_mv'"): + self.assertEqual(ouea.data_available_in_mv, None) + with self.subTest(f"check 'waiting_for_new_trigger'"): + self.assertEqual(ouea.waiting_for_new_trigger, True) + with self.subTest(f"check 'trigger_events_left_over'"): + self.assertEqual(ouea.trigger_events_left_over, True) with self.subTest(f"check 'mw'"): np.testing.assert_equal(ouea.mw, [[] for _ in range(2)]) with self.subTest(f"check 'tw_size'"): From 10caf2e423902481f1caa2fad69298fdb6d64761 Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Thu, 14 Jul 2022 09:04:12 +0200 Subject: [PATCH 07/16] corrected docstring; corrected movment of the memory window --- elephant/online.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/elephant/online.py b/elephant/online.py index 3c64e1f7c..1f582bfd4 100644 --- a/elephant/online.py +++ b/elephant/online.py @@ -507,7 +507,7 @@ def _save_idw_into_mw(self, idw): """ Save in-incoming data window (IDW) into memory window (MW). - This function appends for each neuron all the spikes which are arriving + This method appends for each neuron all the spikes which are arriving with 'idw' into the respective sub-list of 'mv'. Parameters @@ -639,7 +639,7 @@ def _set_saw_positions(self, t_start, t_stop, win_size, win_step, Temporal size of the advancing step of the sliding analysis window. bin_size : pq.Quantity Temporal length of the histogram bins, which were used to bin - the 'spiketrains' in '_apply_bw_tw()'. + the 'spiketrains' in '_apply_bw_to_tw()'. Warns ----- @@ -694,8 +694,8 @@ def _move_saw_over_tw(self, t_stop_idw): joint-surprise is determined just when the user calls 'get_results()'. This is due to the dependency of the distribution from which 'Js' is calculated on the attributes 'n_emp' and 'n_exp'. Updating / changing - 'n_emp' and 'n_exp' changes also this distribution, so that it not any - more possible to simply sum the joint-surprise values of different + 'n_emp' and 'n_exp' changes also this distribution, so that it is not + possible anymore to simply sum the joint-surprise values of different trials at the same sliding analysis window position, because they were based on different distributions. @@ -741,8 +741,8 @@ def _move_saw_over_tw(self, t_stop_idw): if i == self.n_windows - 1: # last SAW position finished self.saw_pos_counter = 0 # move MV after SAW is finished with analysis of one trial - self._move_mw(new_t_start=self.trigger_events[ - self.tw_counter] + self.tw_size) + self._move_mw(new_t_start=self.trigger_events[self.tw_counter] + + self.trigger_post_size) # save analysed trial for visualization if self.save_n_trials: _trial_start = 0 * pq.s From 651f2f97a7a7d7af9d3c55d3093317a23f0d6954 Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Tue, 19 Jul 2022 15:58:33 +0200 Subject: [PATCH 08/16] update_uea() accepts now also list of numpy.ndarray for 'spiketrains' and numpy.ndarray for 'events' parameters, in the later case three additional parameters must be specified, namely: 't_start', 't_stop' and 'time_unit'; extended TestCase class to cover also these alternative types for the 'spiketrains' and 'events' representation --- elephant/online.py | 49 +++++++++---- elephant/test/test_online.py | 135 +++++++++++++++++++++-------------- 2 files changed, 119 insertions(+), 65 deletions(-) diff --git a/elephant/online.py b/elephant/online.py index 1f582bfd4..be110d656 100644 --- a/elephant/online.py +++ b/elephant/online.py @@ -350,7 +350,8 @@ def reset(self, bw_size=0.005 * pq.s, trigger_events=None, trigger_post_size, saw_size, saw_step, n_neurons, pattern_hash, time_unit, save_n_trials) - def update_uea(self, spiketrains, events=None): + def update_uea(self, spiketrains, events=None, + t_start=None, t_stop=None, time_unit=None): """ Update unitary event analysis (UEA) with new arriving spike data from the incoming data window (IDW). @@ -364,10 +365,19 @@ def update_uea(self, spiketrains, events=None): Parameters ---------- - spiketrains : list of neo.SpikeTrain objects + spiketrains : list of neo.SpikeTrain or list of numpy.ndarray Spike times of the analysed neurons. - events : list of pq.Quantity + events : list of pq.Quantity or list of numpy.ndarray Time points of the trial defining trigger events. + t_start : float + Start time of the IDW. + Required if 'spiketrains' is a list of numpy.ndarray. + t_stop : float + Stop time of the IDW. + Required if 'spiketrains' is a list of numpy.ndarray. + time_unit : string + Name of the time unit used for representing the spikes in the IDW. + E.g. 's', 'ms'. Warns ----- @@ -391,25 +401,40 @@ def update_uea(self, spiketrains, events=None): of the current trial """ - # rescale spiketrains to time_unit - spiketrains = [st.rescale(self.time_unit) - if st.t_start.units == st.units == st.t_stop - else st.rescale(st.units).rescale(self.time_unit) - for st in spiketrains] + if isinstance(spiketrains[0], neo.SpikeTrain): + # rescale spiketrains to time_unit + spiketrains = [st.rescale(self.time_unit) + if st.t_start.units == st.units == st.t_stop + else st.rescale(st.units).rescale(self.time_unit) + for st in spiketrains] + elif isinstance(spiketrains[0], np.ndarray): + if t_start is None or t_stop is None or time_unit is None: + raise ValueError("'spiketrains' is a list of np.array(), thus" + "'t_start', 't_stop' and 'time_unit' must be" + "specified!") + else: + spiketrains = [neo.SpikeTrain( + times=st, t_start=t_start, t_stop=t_stop, + units=time_unit).rescale(self.time_unit) + for st in spiketrains] + # extract relevant time information + idw_t_start = spiketrains[0].t_start + idw_t_stop = spiketrains[0].t_stop if events is None: events = np.array([]) if len(events) > 0: for event in events: - if event not in self.trigger_events: + if isinstance(events, np.ndarray): + event = pq.Quantity(event, time_unit) + if event.rescale(self.time_unit) not in self.trigger_events: self.trigger_events.append(event.rescale(self.time_unit)) self.trigger_events.sort() self.n_trials = len(self.trigger_events) + # save incoming spikes (IDW) into memory (MW) self._save_idw_into_mw(spiketrains) - # extract relevant time information - idw_t_start = spiketrains[0].t_start - idw_t_stop = spiketrains[0].t_stop + # analyse all trials which are available in the memory self.data_available_in_mv = True diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py index 11117f5a1..f8b8fed8c 100644 --- a/elephant/test/test_online.py +++ b/elephant/test/test_online.py @@ -78,7 +78,8 @@ def _visualize_results_of_offline_and_online_uea( def _simulate_buffered_reading(n_buffers, ouea, st1, st2, IDW_length, - length_remainder, events=None): + length_remainder, events=None, + st_type="list_of_neo.SpikeTrain"): if events is None: events = np.array([]) for i in range(n_buffers): @@ -93,14 +94,26 @@ def _simulate_buffered_reading(n_buffers, ouea, st1, st2, IDW_length, if len(events) > 0: idx_events_in_buffer = (events >= buff_t_start) & \ (events <= buff_t_stop) - events_in_buffer = events[idx_events_in_buffer].tolist() + events_in_buffer = events[idx_events_in_buffer]#.tolist() events = events[np.logical_not(idx_events_in_buffer)] - ouea.update_uea( - spiketrains=[ - st1.time_slice(t_start=buff_t_start, t_stop=buff_t_stop), - st2.time_slice(t_start=buff_t_start, t_stop=buff_t_stop)], - events=events_in_buffer) + if st_type == "list_of_neo.SpikeTrain": + ouea.update_uea( + spiketrains=[ + st1.time_slice(t_start=buff_t_start, t_stop=buff_t_stop), + st2.time_slice(t_start=buff_t_start, t_stop=buff_t_stop)], + events=events_in_buffer) + elif st_type == "list_of_numpy_array": + ouea.update_uea( + spiketrains=[ + st1.time_slice(t_start=buff_t_start, t_stop=buff_t_stop).magnitude, + st2.time_slice(t_start=buff_t_start, t_stop=buff_t_stop).magnitude], + events=events_in_buffer, t_start=buff_t_start, + t_stop=buff_t_stop, time_unit=st1.units) + else: + raise ValueError("undefined type for spiktrains representation! " + "Use either list of neo.SpikeTrains or " + "list of numpy arrays") print(f"#buffer = {i}") # DEBUG-aid # # aid to create timelapses # result_dict = ouea.get_results() @@ -163,6 +176,8 @@ def setUpClass(cls): 'data/dataset-1.nix' cls.filepath = download_datasets(cls.repo_path) + cls.st_types = ["list_of_neo.SpikeTrain", "list_of_numpy_array"] + def setUp(self): pass @@ -227,7 +242,7 @@ def _assert_equality_of_result_dicts(self, ue_dict_offline, ue_dict_online, def _test_unitary_events_analysis_with_real_data( self, idw_length, method="pass_events_at_initialization", - time_unit=1 * pq.s): + time_unit=1 * pq.s, st_type="list_of_neo.SpikeTrain"): # Fix random seed to guarantee fixed output random.seed(1224) @@ -277,10 +292,10 @@ def _test_unitary_events_analysis_with_real_data( # perform online unitary event analysis # simulate buffered reading/transport of spiketrains, # i.e. loop over spiketrain list and call update_ue() - _simulate_buffered_reading(n_buffers=n_buffers, ouea=ouea, st1=neo_st1, - st2=neo_st2, IDW_length=IDW_length, - length_remainder=length_remainder, - events=reading_events) + _simulate_buffered_reading( + n_buffers=n_buffers, ouea=ouea, st1=neo_st1, st2=neo_st2, + IDW_length=IDW_length, length_remainder=length_remainder, + events=reading_events, st_type=st_type) ue_dict_online = ouea.get_results() # assert equality between result dicts of standard / online ue version @@ -303,7 +318,7 @@ def _test_unitary_events_analysis_with_real_data( def _test_unitary_events_analysis_with_artificial_data( self, idw_length, method="pass_events_at_initialization", - time_unit=1 * pq.s): + time_unit=1 * pq.s , st_type="list_of_neo.SpikeTrain"): # Fix random seed to guarantee fixed output random.seed(1224) @@ -359,10 +374,10 @@ def _test_unitary_events_analysis_with_artificial_data( # perform online unitary event analysis # simulate buffered reading/transport of spiketrains, # i.e. loop over spiketrain list and call update_ue() - _simulate_buffered_reading(n_buffers=n_buffers, ouea=ouea, st1=st1_long - , st2=st2_long, IDW_length=IDW_length, - length_remainder=length_remainder, - events=reading_events) + _simulate_buffered_reading( + n_buffers=n_buffers, ouea=ouea, st1=st1_long, st2=st2_long, + IDW_length=IDW_length, length_remainder=length_remainder, + events=reading_events, st_type=st_type) ue_dict_online = ouea.get_results() # assert equality between result dicts of standard / online ue version @@ -390,10 +405,12 @@ def test_TW_larger_IDW_artificial_data(self): idw_length = ([0.995, 0.8, 0.6, 0.3, 0.1, 0.05]*pq.s).rescale( self.time_unit) for idw in idw_length: - with self.subTest(f"IDW = {idw}"): - self._test_unitary_events_analysis_with_artificial_data( - idw_length=idw, time_unit=self.time_unit) - self.doCleanups() + for st_type in self.st_types: + with self.subTest(f"IDW = {idw} | st_type: {st_type}"): + self._test_unitary_events_analysis_with_artificial_data( + idw_length=idw, time_unit=self.time_unit, + st_type=st_type) + self.doCleanups() def test_TW_larger_IDW_real_data(self): """Test, if online UE analysis is correct when the trial window is @@ -401,29 +418,35 @@ def test_TW_larger_IDW_real_data(self): idw_length = ([2.05, 2., 1.1, 0.8, 0.1, 0.05]*pq.s).rescale( self.time_unit) for idw in idw_length: - with self.subTest(f"IDW = {idw}"): - self._test_unitary_events_analysis_with_real_data( - idw_length=idw, time_unit=self.time_unit) - self.doCleanups() + for st_type in self.st_types: + with self.subTest(f"IDW = {idw} | st_type: {st_type}"): + self._test_unitary_events_analysis_with_real_data( + idw_length=idw, time_unit=self.time_unit, + st_type=st_type) + self.doCleanups() # test: trial window = in-coming data window (TW = IDW) def test_TW_as_large_as_IDW_real_data(self): """Test, if online UE analysis is correct when the trial window is as large as the in-coming data window with real data.""" idw_length = (2.1*pq.s).rescale(self.time_unit) - with self.subTest(f"IDW = {idw_length}"): - self._test_unitary_events_analysis_with_real_data( - idw_length=idw_length, time_unit=self.time_unit) - self.doCleanups() + for st_type in self.st_types: + with self.subTest(f"IDW = {idw_length} | st_type: {st_type}"): + self._test_unitary_events_analysis_with_real_data( + idw_length=idw_length, time_unit=self.time_unit, + st_type=st_type) + self.doCleanups() def test_TW_as_large_as_IDW_artificial_data(self): """Test, if online UE analysis is correct when the trial window is as large as the in-coming data window with artificial data.""" idw_length = (1*pq.s).rescale(self.time_unit) - with self.subTest(f"IDW = {idw_length}"): - self._test_unitary_events_analysis_with_artificial_data( - idw_length=idw_length, time_unit=self.time_unit) - self.doCleanups() + for st_type in self.st_types: + with self.subTest(f"IDW = {idw_length} | st_type: {st_type}"): + self._test_unitary_events_analysis_with_artificial_data( + idw_length=idw_length, time_unit=self.time_unit, + st_type=st_type) + self.doCleanups() # test: trial window < in-coming data window (TW < IDW) def test_TW_smaller_IDW_artificial_data(self): @@ -431,10 +454,12 @@ def test_TW_smaller_IDW_artificial_data(self): smaller than the in-coming data window with artificial data.""" idw_length = ([1.05, 1.1, 2, 10, 50, 100]*pq.s).rescale(self.time_unit) for idw in idw_length: - with self.subTest(f"IDW = {idw}"): - self._test_unitary_events_analysis_with_artificial_data( - idw_length=idw, time_unit=self.time_unit) - self.doCleanups() + for st_type in self.st_types: + with self.subTest(f"IDW = {idw} | st_type: {st_type}"): + self._test_unitary_events_analysis_with_artificial_data( + idw_length=idw, time_unit=self.time_unit, + st_type=st_type) + self.doCleanups() def test_TW_smaller_IDW_real_data(self): """Test, if online UE analysis is correct when the trial window is @@ -442,28 +467,32 @@ def test_TW_smaller_IDW_real_data(self): idw_length = ([2.15, 2.2, 3, 10, 50, 75.6]*pq.s).rescale( self.time_unit) for idw in idw_length: - with self.subTest(f"IDW = {idw}"): - self._test_unitary_events_analysis_with_real_data( - idw_length=idw, time_unit=self.time_unit) - self.doCleanups() + for st_type in self.st_types: + with self.subTest(f"IDW = {idw} | st_type: {st_type}"): + self._test_unitary_events_analysis_with_real_data( + idw_length=idw, time_unit=self.time_unit, + st_type=st_type) + self.doCleanups() def test_pass_trigger_events_while_buffered_reading_real_data(self): idw_length = (2.1*pq.s).rescale(self.time_unit) - with self.subTest(f"IDW = {idw_length}"): - self._test_unitary_events_analysis_with_real_data( - idw_length=idw_length, - method="pass_events_while_buffered_reading", - time_unit=self.time_unit) - self.doCleanups() + for st_type in self.st_types: + with self.subTest(f"IDW = {idw_length} | st_type: {st_type}"): + self._test_unitary_events_analysis_with_real_data( + idw_length=idw_length, + method="pass_events_while_buffered_reading", + time_unit=self.time_unit, st_type=st_type) + self.doCleanups() def test_pass_trigger_events_while_buffered_reading_artificial_data(self): idw_length = (1*pq.s).rescale(self.time_unit) - with self.subTest(f"IDW = {idw_length}"): - self._test_unitary_events_analysis_with_artificial_data( - idw_length=idw_length, - method="pass_events_while_buffered_reading", - time_unit=self.time_unit) - self.doCleanups() + for st_type in self.st_types: + with self.subTest(f"IDW = {idw_length} | st_type: {st_type}"): + self._test_unitary_events_analysis_with_artificial_data( + idw_length=idw_length, + method="pass_events_while_buffered_reading", + time_unit=self.time_unit, st_type=st_type) + self.doCleanups() def test_reset(self): idw_length = (2.1*pq.s).rescale(self.time_unit) From 38299ecc77f8a4714ea255a6a2295458cd8c4d6f Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Tue, 19 Jul 2022 17:37:36 +0200 Subject: [PATCH 09/16] allow now float/numpy.array in __init__ of OnlineUnitaryEventAnalysis parameters which are not pq.Quantities --- elephant/online.py | 28 +++++++-------- elephant/test/test_online.py | 70 ++++++++++++++++++++++++++---------- 2 files changed, 65 insertions(+), 33 deletions(-) diff --git a/elephant/online.py b/elephant/online.py index be110d656..0112aac9b 100644 --- a/elephant/online.py +++ b/elephant/online.py @@ -187,10 +187,10 @@ class OnlineUnitaryEventAnalysis: """ - def __init__(self, bw_size=0.005 * pq.s, trigger_events=None, - trigger_pre_size=0.5 * pq.s, trigger_post_size=0.5 * pq.s, - saw_size=0.1 * pq.s, saw_step=0.005 * pq.s, n_neurons=2, - pattern_hash=None, time_unit=1 * pq.s, save_n_trials=None): + def __init__(self, bw_size=0.005, trigger_events=None, + trigger_pre_size=0.5, trigger_post_size=0.5, + saw_size=0.1, saw_step=0.005, n_neurons=2, + pattern_hash=None, time_unit='s', save_n_trials=None): """ Constructor. Initializes all attributes of the new instance. """ @@ -200,22 +200,22 @@ def __init__(self, bw_size=0.005 * pq.s, trigger_events=None, self.trigger_events_left_over = True # save constructor parameters - if time_unit.units != (pq.s and pq.ms): + if time_unit not in ['s', 'ms'] and time_unit not in [pq.s, pq.ms]: warnings.warn(message=f"Unusual time units like {time_unit} can " f"cause numerical imprecise results. " f"Use `ms` or `s` instead!", category=UserWarning) - self.time_unit = time_unit - self.bw_size = bw_size.rescale(self.time_unit) + self.time_unit = pq.Quantity(1, time_unit) + self.bw_size = pq.Quantity(bw_size, self.time_unit) if trigger_events is None: self.trigger_events = [] else: - self.trigger_events = trigger_events.rescale( - self.time_unit).tolist() - self.trigger_pre_size = trigger_pre_size.rescale(self.time_unit) - self.trigger_post_size = trigger_post_size.rescale(self.time_unit) - self.saw_size = saw_size.rescale(self.time_unit) # multiple of bw_size - self.saw_step = saw_step.rescale(self.time_unit) # multiple of bw_size + self.trigger_events = pq.Quantity(trigger_events, + self.time_unit).tolist() + self.trigger_pre_size = pq.Quantity(trigger_pre_size, self.time_unit) + self.trigger_post_size = pq.Quantity(trigger_post_size, self.time_unit) + self.saw_size = pq.Quantity(saw_size, self.time_unit) # multiple of bw_size + self.saw_step = pq.Quantity(saw_step, self.time_unit) # multiple of bw_size self.n_neurons = n_neurons if pattern_hash is None: pattern = [1] * n_neurons @@ -257,7 +257,7 @@ def __init__(self, bw_size=0.005 * pq.s, trigger_events=None, win_size=self.saw_size.rescale(pq.ms), win_step=self.saw_step.rescale(pq.ms), method=self.method, - t_start=0 * time_unit, + t_start=0 * self.time_unit, t_stop=self.tw_size, n_surrogates=self.n_surrogates) diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py index f8b8fed8c..b6863dda0 100644 --- a/elephant/test/test_online.py +++ b/elephant/test/test_online.py @@ -280,15 +280,31 @@ def _test_unitary_events_analysis_with_real_data( # create instance of OnlineUnitaryEventAnalysis _last_n_trials = min(self.last_n_trials, len(spiketrains)) - ouea = OnlineUnitaryEventAnalysis( - bw_size=(0.005 * pq.s).rescale(time_unit), - trigger_pre_size=(0. * pq.s).rescale(time_unit), - trigger_post_size=(2.1 * pq.s).rescale(time_unit), - saw_size=(0.1 * pq.s).rescale(time_unit), - saw_step=(0.005 * pq.s).rescale(time_unit), - trigger_events=init_events, - time_unit=time_unit, - save_n_trials=_last_n_trials) + ouea = None + if st_type == "list_of_neo.SpikeTrain": + ouea = OnlineUnitaryEventAnalysis( + bw_size=(0.005 * pq.s).rescale(time_unit), + trigger_pre_size=(0. * pq.s).rescale(time_unit), + trigger_post_size=(2.1 * pq.s).rescale(time_unit), + saw_size=(0.1 * pq.s).rescale(time_unit), + saw_step=(0.005 * pq.s).rescale(time_unit), + trigger_events=init_events, + time_unit=time_unit, + save_n_trials=_last_n_trials) + elif st_type == "list_of_numpy_array": + ouea = OnlineUnitaryEventAnalysis( + bw_size=(0.005 * pq.s).rescale(time_unit).magnitude, + trigger_pre_size=(0. * pq.s).rescale(time_unit).magnitude, + trigger_post_size=(2.1 * pq.s).rescale(time_unit).magnitude, + saw_size=(0.1 * pq.s).rescale(time_unit).magnitude, + saw_step=(0.005 * pq.s).rescale(time_unit).magnitude, + trigger_events=init_events.magnitude, + time_unit=time_unit.__str__().split(" ")[1], + save_n_trials=_last_n_trials) + else: + ValueError("undefined type for spiktrains representation! " + "Use either list of neo.SpikeTrains or " + "list of numpy arrays") # perform online unitary event analysis # simulate buffered reading/transport of spiketrains, # i.e. loop over spiketrain list and call update_ue() @@ -318,7 +334,7 @@ def _test_unitary_events_analysis_with_real_data( def _test_unitary_events_analysis_with_artificial_data( self, idw_length, method="pass_events_at_initialization", - time_unit=1 * pq.s , st_type="list_of_neo.SpikeTrain"): + time_unit=1 * pq.s, st_type="list_of_neo.SpikeTrain"): # Fix random seed to guarantee fixed output random.seed(1224) @@ -362,15 +378,31 @@ def _test_unitary_events_analysis_with_artificial_data( # create instance of OnlineUnitaryEventAnalysis _last_n_trials = min(self.last_n_trials, len(spiketrains)) - ouea = OnlineUnitaryEventAnalysis( - bw_size=(0.005 * pq.s).rescale(time_unit), - trigger_pre_size=trigger_pre_size, - trigger_post_size=trigger_post_size, - saw_size=(0.1 * pq.s).rescale(time_unit), - saw_step=(0.005 * pq.s).rescale(time_unit), - trigger_events=init_events, - time_unit=time_unit, - save_n_trials=_last_n_trials) + ouea = None + if st_type == "list_of_neo.SpikeTrain": + ouea = OnlineUnitaryEventAnalysis( + bw_size=(0.005 * pq.s).rescale(time_unit), + trigger_pre_size=trigger_pre_size, + trigger_post_size=trigger_post_size, + saw_size=(0.1 * pq.s).rescale(time_unit), + saw_step=(0.005 * pq.s).rescale(time_unit), + trigger_events=init_events, + time_unit=time_unit, + save_n_trials=_last_n_trials) + elif st_type == "list_of_numpy_array": + ouea = OnlineUnitaryEventAnalysis( + bw_size=(0.005 * pq.s).rescale(time_unit).magnitude, + trigger_pre_size=trigger_pre_size.magnitude, + trigger_post_size=trigger_post_size.magnitude, + saw_size=(0.1 * pq.s).rescale(time_unit).magnitude, + saw_step=(0.005 * pq.s).rescale(time_unit).magnitude, + trigger_events=init_events.magnitude, + time_unit=time_unit.__str__().split(" ")[1], + save_n_trials=_last_n_trials) + else: + ValueError("undefined type for spiktrains representation! " + "Use either list of neo.SpikeTrains or " + "list of numpy arrays") # perform online unitary event analysis # simulate buffered reading/transport of spiketrains, # i.e. loop over spiketrain list and call update_ue() From 776965c5661b1e320479a6338c906036cf85d510 Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Fri, 29 Jul 2022 16:04:48 +0200 Subject: [PATCH 10/16] bug fix: use tw_counter instead of n_trials to scale rate_avg, because tw_counter represents the number of already analyzed trials and n_trials represents the number of trials for which a trigger event was given; tw_counter <= n_trials, i.e. it is possible to provide more trigger events than could be actually analyzed due to the simulation duration (providing trigger events that would occure after the simulation end) --- elephant/online.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elephant/online.py b/elephant/online.py index 0112aac9b..c95b8d7e2 100644 --- a/elephant/online.py +++ b/elephant/online.py @@ -316,7 +316,7 @@ def get_results(self): self.n_exp.astype(np.float64)).flatten() self.Js = jointJ(p) self.rate_avg = (self.rate_avg * (self.saw_size / self.bw_size)) / \ - (self.saw_size * self.n_trials) + (self.saw_size * self.tw_counter) return { 'Js': self.Js.reshape( (self.n_windows, self.n_hashes)).astype(np.float32), From dab4592696681279314ae96d1118b14c51f5a28a Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Fri, 29 Jul 2022 16:20:35 +0200 Subject: [PATCH 11/16] uncomment viziphant import and function for visualization of offline and online UEA results --- elephant/test/test_online.py | 46 ++++++++++++++++++------------------ 1 file changed, 23 insertions(+), 23 deletions(-) diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py index b6863dda0..413a03a2e 100644 --- a/elephant/test/test_online.py +++ b/elephant/test/test_online.py @@ -6,7 +6,7 @@ import neo import numpy as np import quantities as pq -import viziphant +# import viziphant from elephant.datasets import download_datasets from elephant.online import OnlineUnitaryEventAnalysis @@ -53,28 +53,28 @@ def _generate_spiketrains(freq, length, trigger_events, injection_pos, return spiketrains, st1, st2 -def _visualize_results_of_offline_and_online_uea( - spiketrains, online_trials, ue_dict_offline, ue_dict_online, alpha): - # rescale input-params 'bin_size', win_size' and 'win_step' to ms, - # because plot_ue() expects these parameters in ms - ue_dict_offline["input_parameters"]["bin_size"].units = pq.ms - ue_dict_offline["input_parameters"]["win_size"].units = pq.ms - ue_dict_offline["input_parameters"]["win_step"].units = pq.ms - viziphant.unitary_event_analysis.plot_ue( - spiketrains, Js_dict=ue_dict_offline, significance_level=alpha, - unit_real_ids=['1', '2'], suptitle="offline") - # plt.show() - # reorder and rename indices-dict of ue_dict_online, if only the last - # n-trials were saved; indices-entries of unused trials are overwritten - if len(online_trials) < len(spiketrains): - _diff_n_trials = len(spiketrains) - len(online_trials) - for i in range(len(online_trials)): - ue_dict_online["indices"][f"trial{i}"] = \ - ue_dict_online["indices"].pop(f"trial{i+_diff_n_trials}") - viziphant.unitary_event_analysis.plot_ue( - online_trials, Js_dict=ue_dict_online, significance_level=alpha, - unit_real_ids=['1', '2'], suptitle="online") - # plt.show() +# def _visualize_results_of_offline_and_online_uea( +# spiketrains, online_trials, ue_dict_offline, ue_dict_online, alpha): +# # rescale input-params 'bin_size', win_size' and 'win_step' to ms, +# # because plot_ue() expects these parameters in ms +# ue_dict_offline["input_parameters"]["bin_size"].units = pq.ms +# ue_dict_offline["input_parameters"]["win_size"].units = pq.ms +# ue_dict_offline["input_parameters"]["win_step"].units = pq.ms +# viziphant.unitary_event_analysis.plot_ue( +# spiketrains, Js_dict=ue_dict_offline, significance_level=alpha, +# unit_real_ids=['1', '2'], suptitle="offline") +# # plt.show() +# # reorder and rename indices-dict of ue_dict_online, if only the last +# # n-trials were saved; indices-entries of unused trials are overwritten +# if len(online_trials) < len(spiketrains): +# _diff_n_trials = len(spiketrains) - len(online_trials) +# for i in range(len(online_trials)): +# ue_dict_online["indices"][f"trial{i}"] = \ +# ue_dict_online["indices"].pop(f"trial{i+_diff_n_trials}") +# viziphant.unitary_event_analysis.plot_ue( +# online_trials, Js_dict=ue_dict_online, significance_level=alpha, +# unit_real_ids=['1', '2'], suptitle="online") +# plt.show() def _simulate_buffered_reading(n_buffers, ouea, st1, st2, IDW_length, From 30a51b35eadf0a23d279f402f298e984a8e55f23 Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Thu, 1 Sep 2022 16:59:23 +0200 Subject: [PATCH 12/16] code refactoring: replaced abbreviations with full name); simlified some calculations; fixed PEP issues (not all) --- elephant/online.py | 426 +++++++++++++++++++---------------- elephant/test/test_online.py | 191 +++++++--------- 2 files changed, 308 insertions(+), 309 deletions(-) diff --git a/elephant/online.py b/elephant/online.py index c95b8d7e2..c672bcb8c 100644 --- a/elephant/online.py +++ b/elephant/online.py @@ -22,7 +22,7 @@ class OnlineUnitaryEventAnalysis: Parameters ---------- - bw_size : pq.Quantity + bin_window_size : pq.Quantity Size of the bin window, which is used to bin the spike trains. trigger_events : pq.Quantity Quantity array of time points around which the trials are defined. @@ -34,11 +34,11 @@ class OnlineUnitaryEventAnalysis: trigger_post_size : pq.Quantity Interval size after the trigger event. It is used with 'trigger_pre_size' to define the trial. - saw_size : pq.Quantity + sliding_analysis_window_size : pq.Quantity Size of the sliding analysis window, which is used to perform the UEA on the trial segments. It advances with a step size defined by - 'saw_step'. - saw_step : pq.Quantity + 'sliding_analysis_window_step'. + sliding_analysis_window_step : pq.Quantity Size of the step which is used to advance the sliding analysis window to its next position / next trial segment to analyze. n_neurons : int @@ -59,7 +59,7 @@ class OnlineUnitaryEventAnalysis: Attributes ---------- - data_available_in_mv : boolean + data_available_in_memory_window : boolean Reflects the status of spike trains in the memory window. It is `True', when spike trains are in the memory window which were not yet analyzed. Otherwise, it is `False'. @@ -74,32 +74,32 @@ class OnlineUnitaryEventAnalysis: Reflects the status of the trial defining events in the 'trigger_events' list. It is `True`, when there are events left, which were not analyzed yet. Otherwise, it is `False`. - mw : list of lists + memory_window : list of lists Contains for each neuron the spikes which are currently available in the memory window. * 0-axis --> Neurons * 1-axis --> Spike times - tw_size : pq.Quantity + trial_window_size : pq.Quantity The size of the trial window. It is the sum of 'trigger_pre_size' and 'trigger_post_size'. - tw : list of lists + trial_window : list of lists Contains for each neuron the spikes which belong to the current trial and are available in the memory window. * 0-axis --> Neurons * 1-axis --> Spike times - tw_counter : int + trial_counter : int Counts how many trails are yet analyzed. n_bins : int Number of bins which are used for the binning of the spikes of a trial. - bw : np.array of booleans + bin_window : np.array of booleans A binned representation of the current trial window. `True` indicates the presence of a spike in the bin and `False` indicates absence of a spike. * 0-axis --> Neurons * 1-axis --> Index position of the bin - saw_pos_counter : int + sliding_analysis_window_position : int Represents the current position of the sliding analysis window. - n_windows : int + n_sliding_analysis_windows : int Total number of positions of the sliding analysis window. n_trials : int Total number of trials to analyze. @@ -163,9 +163,9 @@ class OnlineUnitaryEventAnalysis: update_uea(spiketrains, events) Updates the entries of the result dictionary by processing the new arriving 'spiketrains' and trial defining trigger 'events'. - reset(bw_size, trigger_events, trigger_pre_size, trigger_post_size, - saw_size, saw_step, n_neurons, pattern_hash, time_unit, - save_n_trials) + reset(bin_window_size, trigger_events, trigger_pre_size, trigger_post_size, + sliding_analysis_window_size, sliding_analysis_window_step, + n_neurons, pattern_hash, time_unit, save_n_trials) Resets all class attributes to their initial (default) value. It is actually a re-initialization which allows parameter adjustments. get_all_saved_trials() @@ -176,26 +176,18 @@ class OnlineUnitaryEventAnalysis: ------- see 'get_results()' in Methods section - Notes - ----- - Common abbreviations which are used in both code and documentation: - bw = bin window - tw = trial window - saw = sliding analysis window - idw = incoming data window - mw = memory window - """ - def __init__(self, bw_size=0.005, trigger_events=None, + def __init__(self, bin_window_size=0.005, trigger_events=None, trigger_pre_size=0.5, trigger_post_size=0.5, - saw_size=0.1, saw_step=0.005, n_neurons=2, + sliding_analysis_window_size=0.1, + sliding_analysis_window_step=0.005, n_neurons=2, pattern_hash=None, time_unit='s', save_n_trials=None): """ Constructor. Initializes all attributes of the new instance. """ # state controlling booleans for the updating algorithm - self.data_available_in_mv = None + self.data_available_in_memory_window = None self.waiting_for_new_trigger = True self.trigger_events_left_over = True @@ -206,7 +198,7 @@ def __init__(self, bw_size=0.005, trigger_events=None, f"Use `ms` or `s` instead!", category=UserWarning) self.time_unit = pq.Quantity(1, time_unit) - self.bw_size = pq.Quantity(bw_size, self.time_unit) + self.bin_window_size = pq.Quantity(bin_window_size, self.time_unit) if trigger_events is None: self.trigger_events = [] else: @@ -214,8 +206,10 @@ def __init__(self, bw_size=0.005, trigger_events=None, self.time_unit).tolist() self.trigger_pre_size = pq.Quantity(trigger_pre_size, self.time_unit) self.trigger_post_size = pq.Quantity(trigger_post_size, self.time_unit) - self.saw_size = pq.Quantity(saw_size, self.time_unit) # multiple of bw_size - self.saw_step = pq.Quantity(saw_step, self.time_unit) # multiple of bw_size + self.sliding_analysis_window_size = \ + pq.Quantity(sliding_analysis_window_size, self.time_unit) + self.sliding_analysis_window_step = \ + pq.Quantity(sliding_analysis_window_step, self.time_unit) self.n_neurons = n_neurons if pattern_hash is None: pattern = [1] * n_neurons @@ -224,22 +218,25 @@ def __init__(self, bw_size=0.005, trigger_events=None, self.pattern_hash = [int(self.pattern_hash)] self.save_n_trials = save_n_trials - # initialize helper variables for the memory window (mw) - self.mw = [[] for _ in range(self.n_neurons)] # list of all spiketimes + # initialize the memory window + self.memory_window = [[] for _ in range(self.n_neurons)] - # initialize helper variables for the trial window (tw) - self.tw_size = self.trigger_pre_size + self.trigger_post_size - self.tw = [[] for _ in range(self.n_neurons)] # pointer to slice of mw - self.tw_counter = 0 + # initialize the trial window and helper variables + self.trial_window_size = self.trigger_pre_size + self.trigger_post_size + self.trial_window = [[] for _ in range(self.n_neurons)] + self.trial_counter = 0 - # initialize helper variables for the bin window (bw) + # initialize the bin window and helper variable self.n_bins = None - self.bw = None # binned copy of tw + self.bin_window = None - # initialize helper variable for the sliding analysis window (saw) - self.saw_pos_counter = 0 - self.n_windows = int(np.round( - (self.tw_size - self.saw_size + self.saw_step) / self.saw_step)) + # initialize the sliding analysis window and helper variable + self.sliding_analysis_window_position = 0 + self.n_sliding_analysis_windows = \ + int(np.round((self.trial_window_size - + self.sliding_analysis_window_size + + self.sliding_analysis_window_step) / + self.sliding_analysis_window_step)) # determine the number trials and the number of patterns (hashes) self.n_trials = len(self.trigger_events) @@ -252,24 +249,27 @@ def __init__(self, bw_size=0.005, trigger_events=None, # to facilitate a later comparison of the used parameters self.method = 'analytic_TrialByTrial' self.n_surrogates = 100 - self.input_parameters = dict(pattern_hash=self.pattern_hash, - bin_size=self.bw_size.rescale(pq.ms), - win_size=self.saw_size.rescale(pq.ms), - win_step=self.saw_step.rescale(pq.ms), - method=self.method, - t_start=0 * self.time_unit, - t_stop=self.tw_size, - n_surrogates=self.n_surrogates) + self.input_parameters = \ + dict(pattern_hash=self.pattern_hash, + bin_size=self.bin_window_size.rescale(pq.ms), + win_size=self.sliding_analysis_window_size.rescale(pq.ms), + win_step=self.sliding_analysis_window_step.rescale(pq.ms), + method=self.method, + t_start=0 * self.time_unit, + t_stop=self.trial_window_size, + n_surrogates=self.n_surrogates) # initialize the intermediate result arrays for the joint surprise - # (js), number of expected coincidences (n_exp), number of empirically + # (Js), number of expected coincidences (n_exp), number of empirically # found coincidences (n_emp), rate average of the analyzed neurons - # (rate_avg), as well as the indices of the saw position where - # coincidences appear + # (rate_avg), as well as the indices of the sliding analysis window + # position where coincidences appear self.Js, self.n_exp, self.n_emp = np.zeros( - (3, self.n_windows, self.n_hashes), dtype=np.float64) + (3, self.n_sliding_analysis_windows, self.n_hashes), + dtype=np.float64) self.rate_avg = np.zeros( - (self.n_windows, self.n_hashes, self.n_neurons), dtype=np.float64) + (self.n_sliding_analysis_windows, self.n_hashes, self.n_neurons), + dtype=np.float64) self.indices = defaultdict(list) def get_all_saved_trials(self): @@ -284,6 +284,7 @@ def get_all_saved_trials(self): : list of list of neo.SpikeTrain A nested list of trials, neurons and their neo.SpikeTrain objects, respectively. + """ return list(self.all_trials) @@ -292,7 +293,8 @@ def get_results(self): Return result dictionary. Prepares the dictionary entries by reshaping them into the correct - shape with the correct dtype. + shape with the correct datatype. Moreover, it calculates 'Js' (could + not be updated continuously) and scales 'n_exp' and 'rate_avg'. Returns ------- @@ -311,28 +313,30 @@ def get_results(self): """ for key in self.indices.keys(): self.indices[key] = np.hstack(self.indices[key]).flatten() - self.n_exp /= (self.saw_size / self.bw_size) + self.n_exp /= (self.sliding_analysis_window_size / + self.bin_window_size) p = self._pval(self.n_emp.astype(np.float64), self.n_exp.astype(np.float64)).flatten() self.Js = jointJ(p) - self.rate_avg = (self.rate_avg * (self.saw_size / self.bw_size)) / \ - (self.saw_size * self.tw_counter) + self.rate_avg = self.rate_avg / (self.bin_window_size * + self.trial_counter) return { - 'Js': self.Js.reshape( - (self.n_windows, self.n_hashes)).astype(np.float32), + 'Js': self.Js.reshape((self.n_sliding_analysis_windows, + self.n_hashes)).astype(np.float32), 'indices': self.indices, - 'n_emp': self.n_emp.reshape( - (self.n_windows, self.n_hashes)).astype(np.float32), - 'n_exp': self.n_exp.reshape( - (self.n_windows, self.n_hashes)).astype(np.float32), + 'n_emp': self.n_emp.reshape((self.n_sliding_analysis_windows, + self.n_hashes)).astype(np.float32), + 'n_exp': self.n_exp.reshape((self.n_sliding_analysis_windows, + self.n_hashes)).astype(np.float32), 'rate_avg': self.rate_avg.reshape( - (self.n_windows, self.n_hashes, self.n_neurons)).astype( - np.float32), + (self.n_sliding_analysis_windows, + self.n_hashes, self.n_neurons)).astype(np.float32), 'input_parameters': self.input_parameters} - def reset(self, bw_size=0.005 * pq.s, trigger_events=None, + def reset(self, bin_window_size=0.005 * pq.s, trigger_events=None, trigger_pre_size=0.5 * pq.s, trigger_post_size=0.5 * pq.s, - saw_size=0.1 * pq.s, saw_step=0.005 * pq.s, n_neurons=2, + sliding_analysis_window_size=0.1 * pq.s, + sliding_analysis_window_step=0.005 * pq.s, n_neurons=2, pattern_hash=None, time_unit=1 * pq.s, save_n_trials=None): """ Resets all class attributes to their initial value. @@ -346,15 +350,16 @@ def reset(self, bw_size=0.005 * pq.s, trigger_events=None, (same as for the constructor; see docstring of constructor for details) """ - self.__init__(bw_size, trigger_events, trigger_pre_size, - trigger_post_size, saw_size, saw_step, n_neurons, + self.__init__(bin_window_size, trigger_events, trigger_pre_size, + trigger_post_size, sliding_analysis_window_size, + sliding_analysis_window_step, n_neurons, pattern_hash, time_unit, save_n_trials) - def update_uea(self, spiketrains, events=None, - t_start=None, t_stop=None, time_unit=None): + def update_uea(self, spiketrains, events=None, t_start=None, t_stop=None, + time_unit=None): """ Update unitary event analysis (UEA) with new arriving spike data from - the incoming data window (IDW). + the incoming data window (incoming data window). This method orchestrates the updating process. It saves the incoming 'spiketrains' into the memory window (MW) and adds also the new @@ -370,14 +375,14 @@ def update_uea(self, spiketrains, events=None, events : list of pq.Quantity or list of numpy.ndarray Time points of the trial defining trigger events. t_start : float - Start time of the IDW. + Start time of the incoming data window. Required if 'spiketrains' is a list of numpy.ndarray. t_stop : float - Stop time of the IDW. + Stop time of the incoming data window. Required if 'spiketrains' is a list of numpy.ndarray. time_unit : string - Name of the time unit used for representing the spikes in the IDW. - E.g. 's', 'ms'. + Name of the time unit used for representing the spikes in the + incoming data window, e.g. 's', 'ms'. Warns ----- @@ -390,15 +395,15 @@ def update_uea(self, spiketrains, events=None, Notes ----- Short summary of the different algorithm major states / substates: - 1. pre/post trial analysis: algorithm is waiting for IDW with - new trigger event - 1.1. IDW contains new trigger event - 1.2. IDW does not contain new trigger event - 2. within trial analysis: algorithm is waiting for IDW with - spikes of current trial - 2.1. IDW contains new trigger event - 2.2. IDW does not contain new trigger event, it just has new spikes - of the current trial + 1. pre/post trial analysis: algorithm is waiting for + incoming data window with new trigger event + 1.1. incoming data window contains new trigger event + 1.2. incoming data window does not contain new trigger event + 2. within trial analysis: algorithm is waiting for incoming data window + with spikes of current trial + 2.1. incoming data window contains new trigger event + 2.2. incoming data window does not contain new trigger event, it + just has new spikes of the current trial """ if isinstance(spiketrains[0], neo.SpikeTrain): @@ -418,8 +423,8 @@ def update_uea(self, spiketrains, events=None, units=time_unit).rescale(self.time_unit) for st in spiketrains] # extract relevant time information - idw_t_start = spiketrains[0].t_start - idw_t_stop = spiketrains[0].t_stop + t_start_incoming_data_window = spiketrains[0].t_start + t_stop_incoming_data_window = spiketrains[0].t_stop if events is None: events = np.array([]) @@ -432,74 +437,83 @@ def update_uea(self, spiketrains, events=None, self.trigger_events.sort() self.n_trials = len(self.trigger_events) - # save incoming spikes (IDW) into memory (MW) - self._save_idw_into_mw(spiketrains) - + # save spikes of incoming data window into memory window + self._save_incoming_data_window_into_memory_window(spiketrains) # analyse all trials which are available in the memory - self.data_available_in_mv = True - while self.data_available_in_mv: - if self.tw_counter == self.n_trials: + self.data_available_in_memory_window = True + while self.data_available_in_memory_window: + if self.trial_counter == self.n_trials: break if self.n_trials == 0: current_trigger_event = np.inf * self.time_unit next_trigger_event = np.inf * self.time_unit else: - current_trigger_event = self.trigger_events[self.tw_counter] - if self.tw_counter <= self.n_trials - 2: + current_trigger_event = self.trigger_events[self.trial_counter] + if self.trial_counter <= self.n_trials - 2: next_trigger_event = self.trigger_events[ - self.tw_counter + 1] + self.trial_counter + 1] else: next_trigger_event = np.inf * self.time_unit # # case 1: pre/post trial analysis, - # i.e. waiting for IDW with new trigger event + # i.e. waiting for incoming data window with new trigger event if self.waiting_for_new_trigger: - # # subcase 1: IDW contains trigger event - if (idw_t_start <= current_trigger_event) & \ - (current_trigger_event <= idw_t_stop): + # # subcase 1: incoming data window contains trigger event + if (t_start_incoming_data_window <= current_trigger_event) & \ + (current_trigger_event <= t_stop_incoming_data_window): self.waiting_for_new_trigger = False if self.trigger_events_left_over: - # define TW around trigger event - self._define_tw(trigger_event=current_trigger_event) - # apply BW to available data in TW - self._apply_bw_to_tw() - # move SAW over available data in TW - self._move_saw_over_tw(t_stop_idw=idw_t_stop) + # define trial window around trigger event + self._define_trial_window( + trigger_event=current_trigger_event) + # apply bin window to available data in trial window + self._binning_of_trial_window() + # move sliding analysis window over available data + # in trial window + self._move_sliding_analysis_window_over_trial_window( + t_stop_incoming_data_window= + t_stop_incoming_data_window) else: pass - # # subcase 2: IDW does not contain trigger event + # # subcase 2: incoming data window does not contain + # trigger event else: - self._move_mw( - new_t_start=idw_t_stop - self.trigger_pre_size) + self._move_memory_window(new_t_start= + t_stop_incoming_data_window - + self.trigger_pre_size) - # # Case 2: within trial analysis, - # i.e. waiting for new IDW with spikes of current trial + # # Case 2: within trial analysis, i.e. waiting for + # new incoming data window with spikes of current trial else: - # # Subcase 3: IDW contains new trigger event - if (idw_t_start <= next_trigger_event) & \ - (next_trigger_event <= idw_t_stop): + # # Subcase 3: incoming data window contains new trigger event + if (t_start_incoming_data_window <= next_trigger_event) & \ + (next_trigger_event <= t_stop_incoming_data_window): # check if an overlap between current / next trial exists - if self._check_tw_overlap( + if self._check_trial_window_overlap( current_trigger_event=current_trigger_event, next_trigger_event=next_trigger_event): warnings.warn( - f"Data in trial {self.tw_counter} will be analysed" - f" twice! Adjust the trigger events and/or " - f"the trial window size.", UserWarning) + f"Data in trial {self.trial_counter} will be " + f"analysed twice! Adjust the trigger events and/or" + f" the trial window size.", UserWarning) else: # no overlap exists pass - # # Subcase 4: IDW does not contain trigger event, - # i.e. just new spikes of the current trial + # # Subcase 4: incoming data window does not contain + # trigger event, i.e. just new spikes of the current trial else: pass if self.trigger_events_left_over: - # define trial TW around trigger event - self._define_tw(trigger_event=current_trigger_event) - # apply BW to available data in TW - self._apply_bw_to_tw() - # move SAW over available data in TW - self._move_saw_over_tw(t_stop_idw=idw_t_stop) + # define trial window around trigger event + self._define_trial_window(trigger_event= + current_trigger_event) + # apply bin window to available data in trial window + self._binning_of_trial_window() + # move sliding analysis window over available data + # in trial window + self._move_sliding_analysis_window_over_trial_window( + t_stop_incoming_data_window= + t_stop_incoming_data_window) else: pass @@ -528,24 +542,26 @@ def _pval(self, n_emp, n_exp): p = 1. - sc.gammaincc(n_emp, n_exp) return p - def _save_idw_into_mw(self, idw): + def _save_incoming_data_window_into_memory_window(self, + incoming_data_window): """ - Save in-incoming data window (IDW) into memory window (MW). + Save in-incoming data window into memory window. This method appends for each neuron all the spikes which are arriving - with 'idw' into the respective sub-list of 'mv'. + with 'incoming_data_window' into the respective sub-list of + 'memory_window'. Parameters --------- - idw : list of pq.Quantity arrays + incoming_data_window : list of pq.Quantity arrays * 0-axis --> Neurons * 1-axis --> Spike times """ for i in range(self.n_neurons): - self.mw[i] += idw[i].tolist() + self.memory_window[i] += incoming_data_window[i].tolist() - def _move_mw(self, new_t_start): + def _move_memory_window(self, new_t_start): """ Move memory window. @@ -561,16 +577,16 @@ def _move_mw(self, new_t_start): """ for i in range(self.n_neurons): - idx = np.where(new_t_start > self.mw[i])[0] + idx = np.where(new_t_start > self.memory_window[i])[0] # print(f"idx = {idx}") - if not len(idx) == 0: # move mv - self.mw[i] = self.mw[i][idx[-1] + 1:] - else: # keep mv - self.data_available_in_mv = False + if not len(idx) == 0: # move memory_window + self.memory_window[i] = self.memory_window[i][idx[-1] + 1:] + else: # keep memory_window + self.data_available_in_memory_window = False - def _define_tw(self, trigger_event): + def _define_trial_window(self, trigger_event): """ - Define trial window (TW) based on a trigger event. + Define trial window based on a trigger event. This method defines the trial window around the 'trigger_event', i.e. it sets the start and stop of the trial, so that it covers the @@ -588,10 +604,12 @@ def _define_tw(self, trigger_event): self.trial_start = trigger_event - self.trigger_pre_size self.trial_stop = trigger_event + self.trigger_post_size for i in range(self.n_neurons): - self.tw[i] = [t for t in self.mw[i] - if (self.trial_start <= t) & (t <= self.trial_stop)] + self.trial_window[i] = [t for t in self.memory_window[i] + if (self.trial_start <= t) & + (t <= self.trial_stop)] - def _check_tw_overlap(self, current_trigger_event, next_trigger_event): + def _check_trial_window_overlap(self, current_trigger_event, + next_trigger_event): """ Check if successive trials do overlap each other. @@ -621,9 +639,9 @@ def _check_tw_overlap(self, current_trigger_event, next_trigger_event): else: return False - def _apply_bw_to_tw(self): + def _binning_of_trial_window(self): """ - Apply bin window (BW) to trial window (TW). + Apply bin window to trial window. Perform the binning and clipping procedure on the trial window, i.e. if at least one spike is within a bin, it is occupied and @@ -631,21 +649,22 @@ def _apply_bw_to_tw(self): """ self.n_bins = int(((self.trial_stop - self.trial_start) / - self.bw_size).simplified.item()) - self.bw = np.zeros((1, self.n_neurons, self.n_bins), dtype=np.int32) + self.bin_window_size).simplified.item()) + self.bin_window = np.zeros((1, self.n_neurons, self.n_bins), + dtype=np.int32) spiketrains = [neo.SpikeTrain(np.array(st) * self.time_unit, t_start=self.trial_start, t_stop=self.trial_stop) - for st in self.tw] - bs = conv.BinnedSpikeTrain(spiketrains, t_start=self.trial_start, - t_stop=self.trial_stop, - bin_size=self.bw_size) - self.bw = bs.to_bool_array() - - def _set_saw_positions(self, t_start, t_stop, win_size, win_step, - bin_size): + for st in self.trial_window] + bst = conv.BinnedSpikeTrain(spiketrains, t_start=self.trial_start, + t_stop=self.trial_stop, + bin_size=self.bin_window_size) + self.bin_window = bst.to_bool_array() + + def _set_sliding_analysis_windows_positions(self, t_start, t_stop, + win_size, win_step, bin_size): """ - Set positions of the sliding analysis window (SAW). + Set positions of the sliding analysis window. Determines the positions of the sliding analysis window with respect to the used window size 'win_size' and the advancing step 'win_step'. Also @@ -677,8 +696,8 @@ def _set_saw_positions(self, t_start, t_stop, win_size, win_step, """ self.t_winpos = _winpos(t_start, t_stop, win_size, win_step, position='left-edge') - while len(self.t_winpos) != self.n_windows: - if len(self.t_winpos) > self.n_windows: + while len(self.t_winpos) != self.n_sliding_analysis_windows: + if len(self.t_winpos) > self.n_sliding_analysis_windows: self.t_winpos = _winpos(t_start, t_stop - win_step / 2, win_size, win_step, position='left-edge') @@ -696,9 +715,10 @@ def _set_saw_positions(self, t_start, t_stop, win_size, win_step, warnings.warn(f"The ratio between the win_step ({win_step}) and " f"the bin_size ({bin_size}) is not an integer") - def _move_saw_over_tw(self, t_stop_idw): + def _move_sliding_analysis_window_over_trial_window( + self, t_stop_incoming_data_window): """ - Move sliding analysis window (SAW) over trial window (TW). + Move sliding analysis window over trial window. This method iterates over each sliding analysis window position and applies at each position the unitary event analysis, i.e. within each @@ -710,8 +730,8 @@ def _move_saw_over_tw(self, t_stop_idw): Parameters ---------- - t_stop_idw : pq.Quantity - Time point at which the current incoming data window (IDW) ends. + t_stop_incoming_data_window : pq.Quantity + Time point at which the current incoming data window ends. Notes ----- @@ -725,71 +745,85 @@ def _move_saw_over_tw(self, t_stop_idw): based on different distributions. """ - # define saw positions - self._set_saw_positions( + # define sliding analysis window positions + self._set_sliding_analysis_windows_positions( t_start=self.trial_start, t_stop=self.trial_stop, - win_size=self.saw_size, win_step=self.saw_step, - bin_size=self.bw_size) + win_size=self.sliding_analysis_window_size, + win_step=self.sliding_analysis_window_step, + bin_size=self.bin_window_size) - # iterate over saw positions - for i in range(self.saw_pos_counter, self.n_windows): + # iterate over sliding analysis window positions + for i in range(self.sliding_analysis_window_position, + self.n_sliding_analysis_windows): p_realtime = self.t_winpos[i] p_bintime = self.t_winpos_bintime[i] - self.t_winpos_bintime[0] - # check if saw filled with data - if p_realtime + self.saw_size <= t_stop_idw: # saw is filled + # check if sliding analysis window is filled with data + # case 1: sliding analysis window is filled + if p_realtime + self.sliding_analysis_window_size <= \ + t_stop_incoming_data_window: mat_win = np.zeros((1, self.n_neurons, self.winsize_bintime)) - n_bins_in_current_saw = self.bw[ + n_bins_in_current_sliding_analysis_window = self.bin_window[ :, p_bintime:p_bintime + self.winsize_bintime].shape[1] - if n_bins_in_current_saw < self.winsize_bintime: + if n_bins_in_current_sliding_analysis_window < \ + self.winsize_bintime: mat_win[0] += np.pad( - self.bw[:, p_bintime:p_bintime + self.winsize_bintime], - (0, self.winsize_bintime - n_bins_in_current_saw), + self.bin_window[:, + p_bintime:p_bintime + self.winsize_bintime], + (0, self.winsize_bintime - + n_bins_in_current_sliding_analysis_window), "minimum")[0:2] else: mat_win[0] += \ - self.bw[:, p_bintime:p_bintime + self.winsize_bintime] + self.bin_window[ + :, p_bintime:p_bintime + self.winsize_bintime] Js_win, rate_avg, n_exp_win, n_emp_win, indices_lst = _UE( mat_win, pattern_hash=self.pattern_hash, method=self.method, n_surrogates=self.n_surrogates) self.rate_avg[i] += rate_avg self.n_exp[i] += (np.round( - n_exp_win * (self.saw_size / self.bw_size))).astype(int) + n_exp_win * (self.sliding_analysis_window_size / + self.bin_window_size))).astype(int) self.n_emp[i] += n_emp_win self.indices_lst = indices_lst if len(self.indices_lst[0]) > 0: - self.indices[f"trial{self.tw_counter}"].append( + self.indices[f"trial{self.trial_counter}"].append( self.indices_lst[0] + p_bintime) - else: # saw is empty / half-filled -> pause iteration - self.saw_pos_counter = i - self.data_available_in_mv = False + # case 2: sliding analysis window is empty / partially filled + else: # -> pause iteration + self.sliding_analysis_window_position = i + self.data_available_in_memory_window = False break - if i == self.n_windows - 1: # last SAW position finished - self.saw_pos_counter = 0 - # move MV after SAW is finished with analysis of one trial - self._move_mw(new_t_start=self.trigger_events[self.tw_counter] - + self.trigger_post_size) + # last sliding analysis window position finished + if i == self.n_sliding_analysis_windows - 1: + self.sliding_analysis_window_position = 0 + # move memory window after sliding analysis window is finished + # with analysis of one trial + self._move_memory_window( + new_t_start=self.trigger_events[self.trial_counter] + + self.trigger_post_size) # save analysed trial for visualization if self.save_n_trials: _trial_start = 0 * pq.s - _trial_stop = self.tw_size - _offset = self.trigger_events[self.tw_counter] - \ - self.trigger_pre_size + _trial_stop = self.trial_window_size + _offset = self.trigger_events[self.trial_counter] - \ + self.trigger_pre_size normalized_spike_times = [] for n in range(self.n_neurons): normalized_spike_times.append( - np.array(self.tw[n]) * self.time_unit - _offset) + np.array(self.trial_window[n]) * self.time_unit - + _offset) self.all_trials.append( [neo.SpikeTrain(normalized_spike_times[m], t_start=_trial_start, t_stop=_trial_stop, units=self.time_unit) for m in range(self.n_neurons)]) - # reset bw - self.bw = np.zeros_like(self.bw) - if self.tw_counter <= self.n_trials - 1: - self.tw_counter += 1 + # reset bin window + self.bin_window = np.zeros_like(self.bin_window) + if self.trial_counter <= self.n_trials - 1: + self.trial_counter += 1 else: self.waiting_for_new_trigger = True self.trigger_events_left_over = False - self.data_available_in_mv = False - print(f"tw_counter = {self.tw_counter}") # DEBUG-aid + self.data_available_in_memory_window = False + print(f"trial_counter = {self.trial_counter}") # DEBUG-aid diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py index 413a03a2e..72a367150 100644 --- a/elephant/test/test_online.py +++ b/elephant/test/test_online.py @@ -2,15 +2,13 @@ import unittest from collections import defaultdict -# import matplotlib.pyplot as plt import neo import numpy as np import quantities as pq -# import viziphant from elephant.datasets import download_datasets from elephant.online import OnlineUnitaryEventAnalysis -from elephant.spike_train_generation import homogeneous_poisson_process +from elephant.spike_train_generation import StationaryPoissonProcess from elephant.unitary_event_analysis import jointJ_window_analysis @@ -21,12 +19,14 @@ def _generate_spiketrains(freq, length, trigger_events, injection_pos, Generate two spiketrains from a homogeneous Poisson process with injected coincidences. """ - st1 = homogeneous_poisson_process(rate=freq, - t_start=(0*pq.s).rescale(time_unit), - t_stop=length.rescale(time_unit)) - st2 = homogeneous_poisson_process(rate=freq, - t_start=(0*pq.s.rescale(time_unit)), - t_stop=length.rescale(time_unit)) + st1 = StationaryPoissonProcess(rate=freq, + t_start=(0*pq.s).rescale(time_unit), + t_stop=length.rescale(time_unit) + ).generate_spiketrain() + st2 = StationaryPoissonProcess(rate=freq, + t_start=(0*pq.s.rescale(time_unit)), + t_stop=length.rescale(time_unit) + ).generate_spiketrain() # inject 10 coincidences within a 0.1s interval for each trial injection = (np.linspace(0, 0.1, 10)*pq.s).rescale(time_unit) all_injections = np.array([]) @@ -53,48 +53,25 @@ def _generate_spiketrains(freq, length, trigger_events, injection_pos, return spiketrains, st1, st2 -# def _visualize_results_of_offline_and_online_uea( -# spiketrains, online_trials, ue_dict_offline, ue_dict_online, alpha): -# # rescale input-params 'bin_size', win_size' and 'win_step' to ms, -# # because plot_ue() expects these parameters in ms -# ue_dict_offline["input_parameters"]["bin_size"].units = pq.ms -# ue_dict_offline["input_parameters"]["win_size"].units = pq.ms -# ue_dict_offline["input_parameters"]["win_step"].units = pq.ms -# viziphant.unitary_event_analysis.plot_ue( -# spiketrains, Js_dict=ue_dict_offline, significance_level=alpha, -# unit_real_ids=['1', '2'], suptitle="offline") -# # plt.show() -# # reorder and rename indices-dict of ue_dict_online, if only the last -# # n-trials were saved; indices-entries of unused trials are overwritten -# if len(online_trials) < len(spiketrains): -# _diff_n_trials = len(spiketrains) - len(online_trials) -# for i in range(len(online_trials)): -# ue_dict_online["indices"][f"trial{i}"] = \ -# ue_dict_online["indices"].pop(f"trial{i+_diff_n_trials}") -# viziphant.unitary_event_analysis.plot_ue( -# online_trials, Js_dict=ue_dict_online, significance_level=alpha, -# unit_real_ids=['1', '2'], suptitle="online") -# plt.show() - - -def _simulate_buffered_reading(n_buffers, ouea, st1, st2, IDW_length, - length_remainder, events=None, - st_type="list_of_neo.SpikeTrain"): +def _simulate_buffered_reading(n_buffers, ouea, st1, st2, + incoming_data_window_size, length_remainder, + events=None, st_type="list_of_neo.SpikeTrain"): if events is None: events = np.array([]) for i in range(n_buffers): - buff_t_start = i * IDW_length + buff_t_start = i * incoming_data_window_size if length_remainder > 1e-7 and i == n_buffers - 1: - buff_t_stop = i * IDW_length + length_remainder + buff_t_stop = i * incoming_data_window_size + length_remainder else: - buff_t_stop = i * IDW_length + IDW_length + buff_t_stop = i * incoming_data_window_size + \ + incoming_data_window_size events_in_buffer = np.array([]) if len(events) > 0: idx_events_in_buffer = (events >= buff_t_start) & \ (events <= buff_t_stop) - events_in_buffer = events[idx_events_in_buffer]#.tolist() + events_in_buffer = events[idx_events_in_buffer] events = events[np.logical_not(idx_events_in_buffer)] if st_type == "list_of_neo.SpikeTrain": @@ -106,8 +83,10 @@ def _simulate_buffered_reading(n_buffers, ouea, st1, st2, IDW_length, elif st_type == "list_of_numpy_array": ouea.update_uea( spiketrains=[ - st1.time_slice(t_start=buff_t_start, t_stop=buff_t_stop).magnitude, - st2.time_slice(t_start=buff_t_start, t_stop=buff_t_stop).magnitude], + st1.time_slice(t_start=buff_t_start, t_stop=buff_t_stop + ).magnitude, + st2.time_slice(t_start=buff_t_start, t_stop=buff_t_stop + ).magnitude], events=events_in_buffer, t_start=buff_t_start, t_stop=buff_t_stop, time_unit=st1.units) else: @@ -115,12 +94,6 @@ def _simulate_buffered_reading(n_buffers, ouea, st1, st2, IDW_length, "Use either list of neo.SpikeTrains or " "list of numpy arrays") print(f"#buffer = {i}") # DEBUG-aid - # # aid to create timelapses - # result_dict = ouea.get_results() - # viziphant.unitary_event_analysis.plot_ue( - # spiketrains[:i+1], Js_dict=result_dict, significance_level=0.05, - # unit_real_ids=['1', '2']) - # plt.savefig(f"plots/timelapse_UE/ue_real_data_buff_{i}.pdf") def _load_real_data(filepath, n_trials, trial_length, time_unit): @@ -247,21 +220,21 @@ def _test_unitary_events_analysis_with_real_data( random.seed(1224) # set relevant variables of this TestCase - n_trials = 36 # determined by real data - TW_length = (2.1 * pq.s).rescale(time_unit) # determined by real data + n_trials = 36 + trial_window_length = (2.1 * pq.s).rescale(time_unit) IDW_length = idw_length.rescale(time_unit) noise_length = (0. * pq.s).rescale(time_unit) trigger_events = (np.arange(0., n_trials * 2.1, 2.1) * pq.s).rescale( time_unit) n_buffers, length_remainder = _calculate_n_buffers( - n_trials=n_trials, tw_length=TW_length, + n_trials=n_trials, tw_length=trial_window_length, noise_length=noise_length, idw_length=IDW_length) # load data and extract spiketrains # 36 trials with 2.1s length and 0s background noise in between trials spiketrains, neo_st1, neo_st2 = _load_real_data( - filepath=self.filepath, n_trials=n_trials, trial_length=TW_length, - time_unit=time_unit) + filepath=self.filepath, n_trials=n_trials, + trial_length=trial_window_length, time_unit=time_unit) # perform standard unitary events analysis ue_dict = jointJ_window_analysis( @@ -283,21 +256,23 @@ def _test_unitary_events_analysis_with_real_data( ouea = None if st_type == "list_of_neo.SpikeTrain": ouea = OnlineUnitaryEventAnalysis( - bw_size=(0.005 * pq.s).rescale(time_unit), + bin_window_size=(0.005 * pq.s).rescale(time_unit), trigger_pre_size=(0. * pq.s).rescale(time_unit), trigger_post_size=(2.1 * pq.s).rescale(time_unit), - saw_size=(0.1 * pq.s).rescale(time_unit), - saw_step=(0.005 * pq.s).rescale(time_unit), + sliding_analysis_window_size=(0.1 * pq.s).rescale(time_unit), + sliding_analysis_window_step=(0.005 * pq.s).rescale(time_unit), trigger_events=init_events, time_unit=time_unit, save_n_trials=_last_n_trials) elif st_type == "list_of_numpy_array": ouea = OnlineUnitaryEventAnalysis( - bw_size=(0.005 * pq.s).rescale(time_unit).magnitude, + bin_window_size=(0.005 * pq.s).rescale(time_unit).magnitude, trigger_pre_size=(0. * pq.s).rescale(time_unit).magnitude, trigger_post_size=(2.1 * pq.s).rescale(time_unit).magnitude, - saw_size=(0.1 * pq.s).rescale(time_unit).magnitude, - saw_step=(0.005 * pq.s).rescale(time_unit).magnitude, + sliding_analysis_window_size= + (0.1 * pq.s).rescale(time_unit).magnitude, + sliding_analysis_window_step= + (0.005 * pq.s).rescale(time_unit).magnitude, trigger_events=init_events.magnitude, time_unit=time_unit.__str__().split(" ")[1], save_n_trials=_last_n_trials) @@ -310,8 +285,9 @@ def _test_unitary_events_analysis_with_real_data( # i.e. loop over spiketrain list and call update_ue() _simulate_buffered_reading( n_buffers=n_buffers, ouea=ouea, st1=neo_st1, st2=neo_st2, - IDW_length=IDW_length, length_remainder=length_remainder, - events=reading_events, st_type=st_type) + incoming_data_window_size=IDW_length, + length_remainder=length_remainder, events=reading_events, + st_type=st_type) ue_dict_online = ouea.get_results() # assert equality between result dicts of standard / online ue version @@ -323,38 +299,31 @@ def _test_unitary_events_analysis_with_real_data( last_n_trials=_last_n_trials, passed_trials=spiketrains, saved_trials=ouea.get_all_saved_trials()) - # visualize results of online and standard UEA for real data - # _visualize_results_of_offline_and_online_uea( - # spiketrains=spiketrains, - # online_trials=ouea.get_all_saved_trials(), - # ue_dict_offline=ue_dict, - # ue_dict_online=ue_dict_online, alpha=0.05) - return ouea def _test_unitary_events_analysis_with_artificial_data( self, idw_length, method="pass_events_at_initialization", time_unit=1 * pq.s, st_type="list_of_neo.SpikeTrain"): - # Fix random seed to guarantee fixed output + # fix random seed to guarantee fixed output random.seed(1224) # set relevant variables of this TestCase n_trials = 40 - TW_length = (1 * pq.s).rescale(time_unit) + trial_window_length = (1 * pq.s).rescale(time_unit) noise_length = (1.5 * pq.s).rescale(time_unit) - IDW_length = idw_length.rescale(time_unit) + incoming_data_window_size = idw_length.rescale(time_unit) trigger_events = (np.arange(0., n_trials*2.5, 2.5) * pq.s).rescale( time_unit) trigger_pre_size = (0. * pq.s).rescale(time_unit) trigger_post_size = (1. * pq.s).rescale(time_unit) n_buffers, length_remainder = _calculate_n_buffers( - n_trials=n_trials, tw_length=TW_length, - noise_length=noise_length, idw_length=IDW_length) + n_trials=n_trials, tw_length=trial_window_length, + noise_length=noise_length, idw_length=incoming_data_window_size) # create two long random homogeneous poisson spiketrains representing # 40 trials with 1s length and 1.5s background noise in between trials spiketrains, st1_long, st2_long = _generate_spiketrains( - freq=5*pq.Hz, length=(TW_length+noise_length)*n_trials, + freq=5*pq.Hz, length=(trial_window_length+noise_length)*n_trials, trigger_events=trigger_events, injection_pos=(0.6 * pq.s).rescale(time_unit), trigger_pre_size=trigger_pre_size, @@ -381,21 +350,23 @@ def _test_unitary_events_analysis_with_artificial_data( ouea = None if st_type == "list_of_neo.SpikeTrain": ouea = OnlineUnitaryEventAnalysis( - bw_size=(0.005 * pq.s).rescale(time_unit), + bin_window_size=(0.005 * pq.s).rescale(time_unit), trigger_pre_size=trigger_pre_size, trigger_post_size=trigger_post_size, - saw_size=(0.1 * pq.s).rescale(time_unit), - saw_step=(0.005 * pq.s).rescale(time_unit), + sliding_analysis_window_size=(0.1 * pq.s).rescale(time_unit), + sliding_analysis_window_step=(0.005 * pq.s).rescale(time_unit), trigger_events=init_events, time_unit=time_unit, save_n_trials=_last_n_trials) elif st_type == "list_of_numpy_array": ouea = OnlineUnitaryEventAnalysis( - bw_size=(0.005 * pq.s).rescale(time_unit).magnitude, + bin_window_size=(0.005 * pq.s).rescale(time_unit).magnitude, trigger_pre_size=trigger_pre_size.magnitude, trigger_post_size=trigger_post_size.magnitude, - saw_size=(0.1 * pq.s).rescale(time_unit).magnitude, - saw_step=(0.005 * pq.s).rescale(time_unit).magnitude, + sliding_analysis_window_size= + (0.1 * pq.s).rescale(time_unit).magnitude, + sliding_analysis_window_step= + (0.005 * pq.s).rescale(time_unit).magnitude, trigger_events=init_events.magnitude, time_unit=time_unit.__str__().split(" ")[1], save_n_trials=_last_n_trials) @@ -408,7 +379,8 @@ def _test_unitary_events_analysis_with_artificial_data( # i.e. loop over spiketrain list and call update_ue() _simulate_buffered_reading( n_buffers=n_buffers, ouea=ouea, st1=st1_long, st2=st2_long, - IDW_length=IDW_length, length_remainder=length_remainder, + incoming_data_window_size=incoming_data_window_size, + length_remainder=length_remainder, events=reading_events, st_type=st_type) ue_dict_online = ouea.get_results() @@ -421,17 +393,10 @@ def _test_unitary_events_analysis_with_artificial_data( last_n_trials=_last_n_trials, passed_trials=spiketrains, saved_trials=ouea.get_all_saved_trials()) - # visualize results of online and standard UEA for artificial data - # _visualize_results_of_offline_and_online_uea( - # spiketrains=spiketrains, - # online_trials=ouea.get_all_saved_trials(), - # ue_dict_offline=ue_dict, - # ue_dict_online=ue_dict_online, alpha=0.01) - return ouea - # test: trial window > in-coming data window (TW > IDW) - def test_TW_larger_IDW_artificial_data(self): + # test: trial window > incoming data window + def test_trial_window_larger_IDW_artificial_data(self): """Test, if online UE analysis is correct when the trial window is larger than the in-coming data window with artificial data.""" idw_length = ([0.995, 0.8, 0.6, 0.3, 0.1, 0.05]*pq.s).rescale( @@ -444,7 +409,7 @@ def test_TW_larger_IDW_artificial_data(self): st_type=st_type) self.doCleanups() - def test_TW_larger_IDW_real_data(self): + def test_trial_window_larger_IDW_real_data(self): """Test, if online UE analysis is correct when the trial window is larger than the in-coming data window with real data.""" idw_length = ([2.05, 2., 1.1, 0.8, 0.1, 0.05]*pq.s).rescale( @@ -457,8 +422,8 @@ def test_TW_larger_IDW_real_data(self): st_type=st_type) self.doCleanups() - # test: trial window = in-coming data window (TW = IDW) - def test_TW_as_large_as_IDW_real_data(self): + # test: trial window = incoming data window + def test_trial_window_as_large_as_IDW_real_data(self): """Test, if online UE analysis is correct when the trial window is as large as the in-coming data window with real data.""" idw_length = (2.1*pq.s).rescale(self.time_unit) @@ -469,7 +434,7 @@ def test_TW_as_large_as_IDW_real_data(self): st_type=st_type) self.doCleanups() - def test_TW_as_large_as_IDW_artificial_data(self): + def test_trial_window_as_large_as_IDW_artificial_data(self): """Test, if online UE analysis is correct when the trial window is as large as the in-coming data window with artificial data.""" idw_length = (1*pq.s).rescale(self.time_unit) @@ -480,8 +445,8 @@ def test_TW_as_large_as_IDW_artificial_data(self): st_type=st_type) self.doCleanups() - # test: trial window < in-coming data window (TW < IDW) - def test_TW_smaller_IDW_artificial_data(self): + # test: trial window < incoming data window + def test_trial_window_smaller_IDW_artificial_data(self): """Test, if online UE analysis is correct when the trial window is smaller than the in-coming data window with artificial data.""" idw_length = ([1.05, 1.1, 2, 10, 50, 100]*pq.s).rescale(self.time_unit) @@ -493,7 +458,7 @@ def test_TW_smaller_IDW_artificial_data(self): st_type=st_type) self.doCleanups() - def test_TW_smaller_IDW_real_data(self): + def test_trial_window_smaller_IDW_real_data(self): """Test, if online UE analysis is correct when the trial window is smaller than the in-coming data window with real data.""" idw_length = ([2.15, 2.2, 3, 10, 50, 75.6]*pq.s).rescale( @@ -503,7 +468,7 @@ def test_TW_smaller_IDW_real_data(self): with self.subTest(f"IDW = {idw} | st_type: {st_type}"): self._test_unitary_events_analysis_with_real_data( idw_length=idw, time_unit=self.time_unit, - st_type=st_type) + st_type=st_type) self.doCleanups() def test_pass_trigger_events_while_buffered_reading_real_data(self): @@ -536,17 +501,17 @@ def test_reset(self): ouea.reset() # check all class attributes with self.subTest(f"check 'bw_size'"): - self.assertEqual(ouea.bw_size, 0.005 * pq.s) + self.assertEqual(ouea.bin_window_size, 0.005 * pq.s) with self.subTest(f"check 'trigger_events'"): self.assertEqual(ouea.trigger_events, []) with self.subTest(f"check 'trigger_pre_size'"): self.assertEqual(ouea.trigger_pre_size, 0.5 * pq.s) with self.subTest(f"check 'trigger_post_size'"): self.assertEqual(ouea.trigger_post_size, 0.5 * pq.s) - with self.subTest(f"check 'saw_size'"): - self.assertEqual(ouea.saw_size, 0.1 * pq.s) - with self.subTest(f"check 'saw_step'"): - self.assertEqual(ouea.saw_step, 0.005 * pq.s) + with self.subTest(f"check 'sliding_analysis_window_size'"): + self.assertEqual(ouea.sliding_analysis_window_size, 0.1 * pq.s) + with self.subTest(f"check 'sliding_analysis_window_step'"): + self.assertEqual(ouea.sliding_analysis_window_step, 0.005 * pq.s) with self.subTest(f"check 'n_neurons'"): self.assertEqual(ouea.n_neurons, 2) with self.subTest(f"check 'pattern_hash'"): @@ -555,28 +520,28 @@ def test_reset(self): self.assertEqual(ouea.time_unit, 1*pq.s) with self.subTest(f"check 'save_n_trials'"): self.assertEqual(ouea.save_n_trials, None) - with self.subTest(f"check 'data_available_in_mv'"): - self.assertEqual(ouea.data_available_in_mv, None) + with self.subTest(f"check 'data_available_in_memory_window'"): + self.assertEqual(ouea.data_available_in_memory_window, None) with self.subTest(f"check 'waiting_for_new_trigger'"): self.assertEqual(ouea.waiting_for_new_trigger, True) with self.subTest(f"check 'trigger_events_left_over'"): self.assertEqual(ouea.trigger_events_left_over, True) with self.subTest(f"check 'mw'"): - np.testing.assert_equal(ouea.mw, [[] for _ in range(2)]) + np.testing.assert_equal(ouea.memory_window, [[] for _ in range(2)]) with self.subTest(f"check 'tw_size'"): - self.assertEqual(ouea.tw_size, 1 * pq.s) + self.assertEqual(ouea.trial_window_size, 1 * pq.s) with self.subTest(f"check 'tw'"): - np.testing.assert_equal(ouea.tw, [[] for _ in range(2)]) + np.testing.assert_equal(ouea.trial_window, [[] for _ in range(2)]) with self.subTest(f"check 'tw_counter'"): - self.assertEqual(ouea.tw_counter, 0) + self.assertEqual(ouea.trial_counter, 0) with self.subTest(f"check 'n_bins'"): self.assertEqual(ouea.n_bins, None) with self.subTest(f"check 'bw'"): - self.assertEqual(ouea.bw, None) - with self.subTest(f"check 'saw_pos_counter'"): - self.assertEqual(ouea.saw_pos_counter, 0) + self.assertEqual(ouea.bin_window, None) + with self.subTest(f"check 'sliding_analysis_window_pos_counter'"): + self.assertEqual(ouea.sliding_analysis_window_position, 0) with self.subTest(f"check 'n_windows'"): - self.assertEqual(ouea.n_windows, 181) + self.assertEqual(ouea.n_sliding_analysis_windows, 181) with self.subTest(f"check 'n_trials'"): self.assertEqual(ouea.n_trials, 0) with self.subTest(f"check 'n_hashes'"): From 3dff2346b1e15babe04eaaf17619e7b27ff6932a Mon Sep 17 00:00:00 2001 From: Maximilian Kramer Date: Fri, 2 Sep 2022 10:55:32 +0200 Subject: [PATCH 13/16] added nixio to requirements-test.txt, because it is needed to execute UnitTests of test_online.py --- requirements/requirements-tests.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements/requirements-tests.txt b/requirements/requirements-tests.txt index 55b033e90..9db87a5dc 100644 --- a/requirements/requirements-tests.txt +++ b/requirements/requirements-tests.txt @@ -1 +1,2 @@ -pytest \ No newline at end of file +pytest +nixio>=1.5.0 \ No newline at end of file From 21056d27827cc92344c5c42e3c9204d94e59acd8 Mon Sep 17 00:00:00 2001 From: Moritz-Alexander-Kern Date: Thu, 17 Nov 2022 16:30:00 +0100 Subject: [PATCH 14/16] commented out debug aid --- elephant/test/test_online.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py index 72a367150..ae51c2fd5 100644 --- a/elephant/test/test_online.py +++ b/elephant/test/test_online.py @@ -93,7 +93,7 @@ def _simulate_buffered_reading(n_buffers, ouea, st1, st2, raise ValueError("undefined type for spiktrains representation! " "Use either list of neo.SpikeTrains or " "list of numpy arrays") - print(f"#buffer = {i}") # DEBUG-aid + # print(f"#buffer = {i}") # DEBUG-aid def _load_real_data(filepath, n_trials, trial_length, time_unit): From 4b6e95bbb9b1f8751ffeff1a4de8e96d1569480d Mon Sep 17 00:00:00 2001 From: Moritz-Alexander-Kern Date: Thu, 17 Nov 2022 16:41:01 +0100 Subject: [PATCH 15/16] fixed pep8 --- elephant/test/test_online.py | 29 +++++++++++++++-------------- 1 file changed, 15 insertions(+), 14 deletions(-) diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py index ae51c2fd5..7d661426e 100644 --- a/elephant/test/test_online.py +++ b/elephant/test/test_online.py @@ -152,6 +152,7 @@ def setUpClass(cls): cls.st_types = ["list_of_neo.SpikeTrain", "list_of_numpy_array"] def setUp(self): + # do nothing pass def _assert_equality_of_passed_and_saved_trials( @@ -269,17 +270,17 @@ def _test_unitary_events_analysis_with_real_data( bin_window_size=(0.005 * pq.s).rescale(time_unit).magnitude, trigger_pre_size=(0. * pq.s).rescale(time_unit).magnitude, trigger_post_size=(2.1 * pq.s).rescale(time_unit).magnitude, - sliding_analysis_window_size= - (0.1 * pq.s).rescale(time_unit).magnitude, - sliding_analysis_window_step= - (0.005 * pq.s).rescale(time_unit).magnitude, + sliding_analysis_window_size=(0.1 * pq.s + ).rescale(time_unit).magnitude, + sliding_analysis_window_step=(0.005 * pq.s + ).rescale(time_unit).magnitude, trigger_events=init_events.magnitude, time_unit=time_unit.__str__().split(" ")[1], save_n_trials=_last_n_trials) else: - ValueError("undefined type for spiktrains representation! " - "Use either list of neo.SpikeTrains or " - "list of numpy arrays") + raise ValueError("undefined type for spiktrains representation! " + "Use either list of neo.SpikeTrains or " + "list of numpy arrays") # perform online unitary event analysis # simulate buffered reading/transport of spiketrains, # i.e. loop over spiketrain list and call update_ue() @@ -363,17 +364,17 @@ def _test_unitary_events_analysis_with_artificial_data( bin_window_size=(0.005 * pq.s).rescale(time_unit).magnitude, trigger_pre_size=trigger_pre_size.magnitude, trigger_post_size=trigger_post_size.magnitude, - sliding_analysis_window_size= - (0.1 * pq.s).rescale(time_unit).magnitude, - sliding_analysis_window_step= - (0.005 * pq.s).rescale(time_unit).magnitude, + sliding_analysis_window_size=(0.1 * pq.s + ).rescale(time_unit).magnitude, + sliding_analysis_window_step=(0.005 * pq.s + ).rescale(time_unit).magnitude, trigger_events=init_events.magnitude, time_unit=time_unit.__str__().split(" ")[1], save_n_trials=_last_n_trials) else: - ValueError("undefined type for spiktrains representation! " - "Use either list of neo.SpikeTrains or " - "list of numpy arrays") + raise ValueError("undefined type for spiktrains representation! " + "Use either list of neo.SpikeTrains or " + "list of numpy arrays") # perform online unitary event analysis # simulate buffered reading/transport of spiketrains, # i.e. loop over spiketrain list and call update_ue() From 162149ad149a3d28e392893b1a5d1b6465e21bac Mon Sep 17 00:00:00 2001 From: Moritz-Alexander-Kern Date: Thu, 17 Nov 2022 17:17:12 +0100 Subject: [PATCH 16/16] fixed pep8 --- elephant/online.py | 44 ++++++++++++++++-------------------- elephant/test/test_online.py | 8 +++---- 2 files changed, 24 insertions(+), 28 deletions(-) diff --git a/elephant/online.py b/elephant/online.py index c672bcb8c..70e033315 100644 --- a/elephant/online.py +++ b/elephant/online.py @@ -7,7 +7,7 @@ import scipy.special as sc import elephant.conversion as conv -from elephant.unitary_event_analysis import * +from elephant.unitary_event_analysis import hash_from_pattern, jointJ from elephant.unitary_event_analysis import _winpos, _bintime, _UE @@ -202,8 +202,8 @@ def __init__(self, bin_window_size=0.005, trigger_events=None, if trigger_events is None: self.trigger_events = [] else: - self.trigger_events = pq.Quantity(trigger_events, - self.time_unit).tolist() + self.trigger_events = pq.Quantity(trigger_events, + self.time_unit).tolist() self.trigger_pre_size = pq.Quantity(trigger_pre_size, self.time_unit) self.trigger_post_size = pq.Quantity(trigger_post_size, self.time_unit) self.sliding_analysis_window_size = \ @@ -417,11 +417,11 @@ def update_uea(self, spiketrains, events=None, t_start=None, t_stop=None, raise ValueError("'spiketrains' is a list of np.array(), thus" "'t_start', 't_stop' and 'time_unit' must be" "specified!") - else: - spiketrains = [neo.SpikeTrain( - times=st, t_start=t_start, t_stop=t_stop, - units=time_unit).rescale(self.time_unit) - for st in spiketrains] + + spiketrains = [neo.SpikeTrain(times=st, t_start=t_start, + t_stop=t_stop, units=time_unit + ).rescale(self.time_unit) + for st in spiketrains] # extract relevant time information t_start_incoming_data_window = spiketrains[0].t_start t_stop_incoming_data_window = spiketrains[0].t_stop @@ -474,14 +474,13 @@ def update_uea(self, spiketrains, events=None, t_start=None, t_stop=None, self._move_sliding_analysis_window_over_trial_window( t_stop_incoming_data_window= t_stop_incoming_data_window) - else: - pass + # # subcase 2: incoming data window does not contain # trigger event else: - self._move_memory_window(new_t_start= - t_stop_incoming_data_window - - self.trigger_pre_size) + self._move_memory_window( + new_t_start=t_stop_incoming_data_window - + self.trigger_pre_size) # # Case 2: within trial analysis, i.e. waiting for # new incoming data window with spikes of current trial @@ -501,12 +500,11 @@ def update_uea(self, spiketrains, events=None, t_start=None, t_stop=None, pass # # Subcase 4: incoming data window does not contain # trigger event, i.e. just new spikes of the current trial - else: - pass + if self.trigger_events_left_over: # define trial window around trigger event - self._define_trial_window(trigger_event= - current_trigger_event) + self._define_trial_window( + trigger_event=current_trigger_event) # apply bin window to available data in trial window self._binning_of_trial_window() # move sliding analysis window over available data @@ -514,8 +512,6 @@ def update_uea(self, spiketrains, events=None, t_start=None, t_stop=None, self._move_sliding_analysis_window_over_trial_window( t_stop_incoming_data_window= t_stop_incoming_data_window) - else: - pass def _pval(self, n_emp, n_exp): """ @@ -579,7 +575,7 @@ def _move_memory_window(self, new_t_start): for i in range(self.n_neurons): idx = np.where(new_t_start > self.memory_window[i])[0] # print(f"idx = {idx}") - if not len(idx) == 0: # move memory_window + if len(idx) != 0: # move memory_window self.memory_window[i] = self.memory_window[i][idx[-1] + 1:] else: # keep memory_window self.data_available_in_memory_window = False @@ -767,8 +763,8 @@ def _move_sliding_analysis_window_over_trial_window( if n_bins_in_current_sliding_analysis_window < \ self.winsize_bintime: mat_win[0] += np.pad( - self.bin_window[:, - p_bintime:p_bintime + self.winsize_bintime], + self.bin_window[:, p_bintime:p_bintime + + self.winsize_bintime], (0, self.winsize_bintime - n_bins_in_current_sliding_analysis_window), "minimum")[0:2] @@ -806,7 +802,7 @@ def _move_sliding_analysis_window_over_trial_window( _trial_start = 0 * pq.s _trial_stop = self.trial_window_size _offset = self.trigger_events[self.trial_counter] - \ - self.trigger_pre_size + self.trigger_pre_size normalized_spike_times = [] for n in range(self.n_neurons): normalized_spike_times.append( @@ -826,4 +822,4 @@ def _move_sliding_analysis_window_over_trial_window( self.waiting_for_new_trigger = True self.trigger_events_left_over = False self.data_available_in_memory_window = False - print(f"trial_counter = {self.trial_counter}") # DEBUG-aid + # print(f"trial_counter = {self.trial_counter}") # DEBUG-aid diff --git a/elephant/test/test_online.py b/elephant/test/test_online.py index 7d661426e..69fb1ebfa 100644 --- a/elephant/test/test_online.py +++ b/elephant/test/test_online.py @@ -145,9 +145,10 @@ def setUpClass(cls): cls.last_n_trials = 50 # download real data once and load it several times later - cls.repo_path = 'tutorials/tutorial_unitary_event_analysis/' \ - 'data/dataset-1.nix' - cls.filepath = download_datasets(cls.repo_path) + repo_path = 'tutorials/tutorial_unitary_event_analysis/' \ + 'data/dataset-1.nix' + cls.repo_path = repo_path + cls.filepath = download_datasets(repo_path) cls.st_types = ["list_of_neo.SpikeTrain", "list_of_numpy_array"] @@ -254,7 +255,6 @@ def _test_unitary_events_analysis_with_real_data( # create instance of OnlineUnitaryEventAnalysis _last_n_trials = min(self.last_n_trials, len(spiketrains)) - ouea = None if st_type == "list_of_neo.SpikeTrain": ouea = OnlineUnitaryEventAnalysis( bin_window_size=(0.005 * pq.s).rescale(time_unit),