From 4011266c071abf357b03f8ac2cb57c7fde9e6bbc Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Tue, 22 Oct 2024 10:42:41 +0200 Subject: [PATCH 01/12] add handling for trial object --- elephant/statistics.py | 65 ++++++++++++++++++++++++------------------ 1 file changed, 37 insertions(+), 28 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 0ab389572..bf74f781e 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -273,7 +273,7 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): r""" Evaluates the empirical Fano factor F of the spike counts of - a list of `neo.SpikeTrain` objects. + a list of `neo.SpikeTrain` objects or `elephant.trials.Trial` object. Given the vector v containing the observed spike counts (one per spike train) in the time window [t0, t1], F is defined as: @@ -288,9 +288,10 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): Parameters ---------- - spiketrains : list + spiketrains : list or elephant.trials.Trial List of `neo.SpikeTrain` or `pq.Quantity` or `np.ndarray` or list of - spike times for which to compute the Fano factor of spike counts. + spike times for which to compute the Fano factor of spike counts, or + an `elephant.trials.Trial` object containing multiple spiketrain lists. warn_tolerance : pq.Quantity In case of a list of input neo.SpikeTrains, if their durations vary by more than `warn_tolerence` in their absolute values, throw a warning @@ -299,10 +300,11 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): Returns ------- - fano : float + fano : float or list of floats The Fano factor of the spike counts of the input spike trains. Returns np.NaN if an empty list is specified, or if all spike trains - are empty. + are empty. If a `Trial` object is provided, returns a list of Fano + factors, one for each trial. Raises ------ @@ -328,29 +330,36 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): 0.07142857142857142 """ - # Build array of spike counts (one per spike train) - spike_counts = np.array([len(st) for st in spiketrains]) - - # Compute FF - if all(count == 0 for count in spike_counts): - # empty list of spiketrains reaches this branch, and NaN is returned - return np.nan - - if all(isinstance(st, neo.SpikeTrain) for st in spiketrains): - if not is_time_quantity(warn_tolerance): - raise TypeError("'warn_tolerance' must be a time quantity.") - durations = [(st.t_stop - st.t_start).simplified.item() - for st in spiketrains] - durations_min = min(durations) - durations_max = max(durations) - if durations_max - durations_min > warn_tolerance.simplified.item(): - warnings.warn("Fano factor calculated for spike trains of " - "different duration (minimum: {_min}s, maximum " - "{_max}s).".format(_min=durations_min, - _max=durations_max)) - - fano = spike_counts.var() / spike_counts.mean() - return fano + def _compute_fano(spiketrains): + # Build array of spike counts (one per spike train) + spike_counts = np.array([len(st) for st in spiketrains]) + + # Compute FF + if all(count == 0 for count in spike_counts): + # empty list of spiketrains reaches this branch, and NaN is returned + return np.nan + + if all(isinstance(st, neo.SpikeTrain) for st in spiketrains): + if not is_time_quantity(warn_tolerance): + raise TypeError("'warn_tolerance' must be a time quantity.") + durations = [(st.t_stop - st.t_start).simplified.item() + for st in spiketrains] + durations_min = min(durations) + durations_max = max(durations) + if durations_max - durations_min > warn_tolerance.simplified.item(): + warnings.warn("Fano factor calculated for spike trains of " + "different duration (minimum: {_min}s, maximum " + "{_max}s).".format(_min=durations_min, + _max=durations_max)) + + fano = spike_counts.var() / spike_counts.mean() + return fano + + if isinstance(spiketrains, elephant.trials.Trials): + return [_compute_fano(spiketrains.get_spiketrains_from_trial_as_list(idx)) + for idx in range(spiketrains.n_trials)] + else: + return _compute_fano(spiketrains) def __variation_check(v, with_nan): From 218a653d7cb6732e1146b9ee7249b33401f73969 Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Tue, 22 Oct 2024 10:42:55 +0200 Subject: [PATCH 02/12] add tests --- elephant/test/test_statistics.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 426111810..4aa3dff8a 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -21,7 +21,7 @@ from elephant import statistics from elephant.spike_train_generation import StationaryPoissonProcess from elephant.test.test_trials import _create_trials_block -from elephant.trials import TrialsFromBlock +from elephant.trials import TrialsFromBlock, TrialsFromLists class IsiTestCase(unittest.TestCase): @@ -289,12 +289,13 @@ def setUp(self): # for cross-validation self.sp_counts[i] = len(st) + self.test_trials = TrialsFromLists([self.test_spiketrains, self.test_spiketrains]) + def test_fanofactor_spiketrains(self): # Test with list of spiketrains self.assertEqual( np.var(self.sp_counts) / np.mean(self.sp_counts), statistics.fanofactor(self.test_spiketrains)) - # One spiketrain in list st = self.test_spiketrains[0] self.assertEqual(statistics.fanofactor([st]), 0.0) @@ -352,6 +353,15 @@ def test_fanofactor_wrong_type(self): self.assertRaises(TypeError, statistics.fanofactor, [st1], warn_tolerance=1e-4) + def test_fanofactor_trials(self): + # Test with Trial object + self.assertEqual( + np.var(self.sp_counts) / np.mean(self.sp_counts), + statistics.fanofactor(self.test_trials)[0]) + self.assertEqual( + np.var(self.sp_counts) / np.mean(self.sp_counts), + statistics.fanofactor(self.test_trials)[1]) + class LVTestCase(unittest.TestCase): def setUp(self): From 29694faab2c4d6a8dc1d4fb3244fc77d6b1ebc7b Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Thu, 14 Nov 2024 14:15:39 +0100 Subject: [PATCH 03/12] add tests for trial object pooling trials or spiketrains --- elephant/test/test_statistics.py | 53 ++++++++++++++++++++------------ 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 4aa3dff8a..f24adaa25 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -269,27 +269,28 @@ def test_mean_firing_rate_with_plain_array_and_units_start_stop_typeerror( class FanoFactorTestCase(unittest.TestCase): - def setUp(self): + @classmethod + def setUpClass(cls): np.random.seed(100) num_st = 300 - self.test_spiketrains = [] - self.test_array = [] - self.test_quantity = [] - self.test_list = [] - self.sp_counts = np.zeros(num_st) + cls.test_spiketrains = [] + cls.test_array = [] + cls.test_quantity = [] + cls.test_list = [] + cls.sp_counts = np.zeros(num_st) for i in range(num_st): r = np.random.rand(np.random.randint(20) + 1) st = neo.core.SpikeTrain(r * pq.ms, t_start=0.0 * pq.ms, t_stop=20.0 * pq.ms) - self.test_spiketrains.append(st) - self.test_array.append(r) - self.test_quantity.append(r * pq.ms) - self.test_list.append(list(r)) + cls.test_spiketrains.append(st) + cls.test_array.append(r) + cls.test_quantity.append(r * pq.ms) + cls.test_list.append(list(r)) # for cross-validation - self.sp_counts[i] = len(st) + cls.sp_counts[i] = len(st) - self.test_trials = TrialsFromLists([self.test_spiketrains, self.test_spiketrains]) + cls.test_trials = TrialsFromLists([cls.test_spiketrains, cls.test_spiketrains]) def test_fanofactor_spiketrains(self): # Test with list of spiketrains @@ -353,14 +354,26 @@ def test_fanofactor_wrong_type(self): self.assertRaises(TypeError, statistics.fanofactor, [st1], warn_tolerance=1e-4) - def test_fanofactor_trials(self): - # Test with Trial object - self.assertEqual( - np.var(self.sp_counts) / np.mean(self.sp_counts), - statistics.fanofactor(self.test_trials)[0]) - self.assertEqual( - np.var(self.sp_counts) / np.mean(self.sp_counts), - statistics.fanofactor(self.test_trials)[1]) + def test_fanofactor_trials_pool_spiketrains(self): + results = statistics.fanofactor(self.test_trials, pool_spike_trains=True) + self.assertEqual(len(results), self.test_trials.n_trials) + for result in results: + self.assertEqual( + np.var(self.sp_counts) / np.mean(self.sp_counts), result) + + def test_fanofactor_trials_pool_trials(self): + results = statistics.fanofactor(self.test_trials, pool_trials=True) + self.assertEqual(len(results), self.test_trials.n_spiketrains_trial_by_trial[0]) + + def test_fanofactor_trials_pool_trials_pool_spiketrains(self): + results = statistics.fanofactor(self.test_trials, pool_trials=True, pool_spike_trains=True) + self.assertEqual(len(results), 1) + + def test_fanofactor_trials_pool_trials_false_pool_spiketrains_false(self): + results = statistics.fanofactor(self.test_trials, pool_trials=False, pool_spike_trains=False) + self.assertEqual(len(results), self.test_trials.n_trials) + for result in results: + self.assertEqual(len(result), self.test_trials.n_spiketrains_trial_by_trial[0]) class LVTestCase(unittest.TestCase): From 1711f43ec9c7db3463d140ad385518124fc8b75d Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Thu, 14 Nov 2024 14:16:04 +0100 Subject: [PATCH 04/12] add parameters pool_trials, pool_spiketrains --- elephant/statistics.py | 39 +++++++++++++++++++++++++++++++-------- 1 file changed, 31 insertions(+), 8 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index dab39e832..f4a3c4031 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -74,7 +74,7 @@ import scipy.signal from numpy import ndarray from scipy.special import erf -from typing import Union +from typing import List, Union import elephant.conversion as conv import elephant.kernels as kernels @@ -270,7 +270,8 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): return rates -def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): +def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray, elephant.trials.Trials], + warn_tolerance:pq.Quantity=0.1 * pq.ms, pool_trials:bool=False, pool_spike_trains:bool=False): r""" Evaluates the empirical Fano factor F of the spike counts of a list of `neo.SpikeTrain` objects or `elephant.trials.Trial` object. @@ -291,12 +292,19 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): spiketrains : list or elephant.trials.Trial List of `neo.SpikeTrain` or `pq.Quantity` or `np.ndarray` or list of spike times for which to compute the Fano factor of spike counts, or - an `elephant.trials.Trial` object containing multiple spiketrain lists. + an `elephant.trials.Trial` object, here the behavior can be controlled with the + pool_trials and pool_spike_trains parameters. warn_tolerance : pq.Quantity In case of a list of input neo.SpikeTrains, if their durations vary by more than `warn_tolerence` in their absolute values, throw a warning (see Notes). Default: 0.1 ms + pool_trials : bool, optional + If True, pool spike trains across trials before computing the Fano factor. + Default: False + pool_spike_trains : bool, optional + If True, pool spike trains within each trial before computing the Fano factor. + Default: False Returns ------- @@ -304,7 +312,7 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): The Fano factor of the spike counts of the input spike trains. Returns np.NaN if an empty list is specified, or if all spike trains are empty. If a `Trial` object is provided, returns a list of Fano - factors, one for each trial. + factors. Raises ------ @@ -330,7 +338,7 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms): 0.07142857142857142 """ - def _compute_fano(spiketrains): + def _compute_fano(spiketrains: neo.SpikeTrain) -> float: # Build array of spike counts (one per spike train) spike_counts = np.array([len(st) for st in spiketrains]) @@ -356,9 +364,24 @@ def _compute_fano(spiketrains): return fano if isinstance(spiketrains, elephant.trials.Trials): - return [_compute_fano(spiketrains.get_spiketrains_from_trial_as_list(idx)) - for idx in range(spiketrains.n_trials)] - else: + if not pool_trials and not pool_spike_trains: + return [[_compute_fano([spiketrain]) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(idx)] + for idx in range(spiketrains.n_trials)] + if not pool_trials and pool_spike_trains: + return [_compute_fano(spiketrains.get_spiketrains_from_trial_as_list(idx)) + for idx in range(spiketrains.n_trials)] + if pool_trials and not pool_spike_trains: + list_of_lists_of_spiketrains = [ + spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no) + for trial_no in range(spiketrains.n_trials)] + return [_compute_fano([list_of_lists_of_spiketrains[trial_no][st_no] + for trial_no in range(len(list_of_lists_of_spiketrains))]) + for st_no in range(len(list_of_lists_of_spiketrains[0]))] + if pool_trials and pool_spike_trains: + return [_compute_fano( + [spiketrain for trial_no in range(spiketrains.n_trials) + for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)])] + else: # Legacy behavior return _compute_fano(spiketrains) From 7b0ccf92323bcd5fa756162da9084804ce22c414 Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Thu, 14 Nov 2024 14:32:49 +0100 Subject: [PATCH 05/12] add type check for pool parameters --- elephant/statistics.py | 14 +++++++++++--- elephant/test/test_statistics.py | 6 ++++++ 2 files changed, 17 insertions(+), 3 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index f4a3c4031..9217b7b4c 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -364,23 +364,31 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float: return fano if isinstance(spiketrains, elephant.trials.Trials): + # Check if parameters are of the correct type + if not isinstance(pool_trials, bool): + raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}") + elif not isinstance(pool_spike_trains, bool): + raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}") if not pool_trials and not pool_spike_trains: return [[_compute_fano([spiketrain]) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(idx)] for idx in range(spiketrains.n_trials)] - if not pool_trials and pool_spike_trains: + elif not pool_trials and pool_spike_trains: return [_compute_fano(spiketrains.get_spiketrains_from_trial_as_list(idx)) for idx in range(spiketrains.n_trials)] - if pool_trials and not pool_spike_trains: + elif pool_trials and not pool_spike_trains: list_of_lists_of_spiketrains = [ spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no) for trial_no in range(spiketrains.n_trials)] return [_compute_fano([list_of_lists_of_spiketrains[trial_no][st_no] for trial_no in range(len(list_of_lists_of_spiketrains))]) for st_no in range(len(list_of_lists_of_spiketrains[0]))] - if pool_trials and pool_spike_trains: + elif pool_trials and pool_spike_trains: return [_compute_fano( [spiketrain for trial_no in range(spiketrains.n_trials) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)])] + else: + raise TypeError(f"pool_spiketrains and pool_trials must be of type: bool, but are " + f"{type(pool_spike_trains)} and {type(pool_trials)}") else: # Legacy behavior return _compute_fano(spiketrains) diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index f24adaa25..57ddeb96a 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -375,6 +375,12 @@ def test_fanofactor_trials_pool_trials_false_pool_spiketrains_false(self): for result in results: self.assertEqual(len(result), self.test_trials.n_spiketrains_trial_by_trial[0]) + def test_fanofactor_trials_pool_spike_trains_wrong_type(self): + self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trains="Wrong Type") + self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type") + self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type", + pool_spike_trains="Wrong Type") + class LVTestCase(unittest.TestCase): def setUp(self): From 1925bf306b9876cbae28a324ac9eda7574656cdd Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:41:29 +0100 Subject: [PATCH 06/12] refactor type annotations --- elephant/statistics.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 9217b7b4c..745453bb0 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -270,8 +270,9 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): return rates -def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray, elephant.trials.Trials], - warn_tolerance:pq.Quantity=0.1 * pq.ms, pool_trials:bool=False, pool_spike_trains:bool=False): +def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray], elephant.trials.Trials], + warn_tolerance: pq.Quantity = 0.1 * pq.ms, pool_trials: bool = False, pool_spike_trains: bool = False + ) -> Union[float, List[float], List[List[float]]]: r""" Evaluates the empirical Fano factor F of the spike counts of a list of `neo.SpikeTrain` objects or `elephant.trials.Trial` object. @@ -296,7 +297,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray, pool_trials and pool_spike_trains parameters. warn_tolerance : pq.Quantity In case of a list of input neo.SpikeTrains, if their durations vary by - more than `warn_tolerence` in their absolute values, throw a warning + more than `warn_tolerance` in their absolute values, throw a warning (see Notes). Default: 0.1 ms pool_trials : bool, optional @@ -308,7 +309,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray, Returns ------- - fano : float or list of floats + fano : float, list of floats or list of list of floats The Fano factor of the spike counts of the input spike trains. Returns np.NaN if an empty list is specified, or if all spike trains are empty. If a `Trial` object is provided, returns a list of Fano @@ -338,7 +339,15 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], pq.Quantity, np.ndarray, 0.07142857142857142 """ - def _compute_fano(spiketrains: neo.SpikeTrain) -> float: + # Check if parameters are of the correct type + if not isinstance(pool_trials, bool): + raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}") + elif not isinstance(pool_spike_trains, bool): + raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}") + elif not is_time_quantity(warn_tolerance): + raise TypeError("'warn_tolerance' must be a time quantity.") + + def _compute_fano(spiketrains: List[neo.SpikeTrain]) -> float: # Build array of spike counts (one per spike train) spike_counts = np.array([len(st) for st in spiketrains]) @@ -348,8 +357,6 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float: return np.nan if all(isinstance(st, neo.SpikeTrain) for st in spiketrains): - if not is_time_quantity(warn_tolerance): - raise TypeError("'warn_tolerance' must be a time quantity.") durations = [(st.t_stop - st.t_start).simplified.item() for st in spiketrains] durations_min = min(durations) @@ -364,11 +371,6 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float: return fano if isinstance(spiketrains, elephant.trials.Trials): - # Check if parameters are of the correct type - if not isinstance(pool_trials, bool): - raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}") - elif not isinstance(pool_spike_trains, bool): - raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}") if not pool_trials and not pool_spike_trains: return [[_compute_fano([spiketrain]) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(idx)] for idx in range(spiketrains.n_trials)] @@ -386,9 +388,6 @@ def _compute_fano(spiketrains: neo.SpikeTrain) -> float: return [_compute_fano( [spiketrain for trial_no in range(spiketrains.n_trials) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)])] - else: - raise TypeError(f"pool_spiketrains and pool_trials must be of type: bool, but are " - f"{type(pool_spike_trains)} and {type(pool_trials)}") else: # Legacy behavior return _compute_fano(spiketrains) From 3c1eb5de785ad9e71b251659adaff2dfc385c82a Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Thu, 14 Nov 2024 15:43:11 +0100 Subject: [PATCH 07/12] add to docstring --- elephant/statistics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/elephant/statistics.py b/elephant/statistics.py index 745453bb0..1b8c50e71 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -320,6 +320,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[ TypeError If the input spiketrains are neo.SpikeTrain objects, but `warn_tolerance` is not a quantity. + If the parameters `pool_trials` or `pool_spike_trains` are not of type bool. Notes ----- From 6bb2fc3b42ffa842aff191ae0f311fedafcb9c4f Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Thu, 14 Nov 2024 16:52:44 +0100 Subject: [PATCH 08/12] add user warning and did refactoring of function --- elephant/statistics.py | 44 +++++++++++++++++--------------- elephant/test/test_statistics.py | 5 ++++ 2 files changed, 28 insertions(+), 21 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 1b8c50e71..552be24e6 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -293,7 +293,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[ spiketrains : list or elephant.trials.Trial List of `neo.SpikeTrain` or `pq.Quantity` or `np.ndarray` or list of spike times for which to compute the Fano factor of spike counts, or - an `elephant.trials.Trial` object, here the behavior can be controlled with the + an `elephant.trials.Trial` object, here the behavior can be controlled with the pool_trials and pool_spike_trains parameters. warn_tolerance : pq.Quantity In case of a list of input neo.SpikeTrains, if their durations vary by @@ -325,7 +325,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[ Notes ----- The check for the equal duration of the input spike trains is performed - only if the input is of type`neo.SpikeTrain`: if you pass a numpy array, + only if the input is of type`neo.SpikeTrain`: if you pass e.g. a numpy array, please make sure that they all have the same duration manually. Examples @@ -346,30 +346,32 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[ elif not isinstance(pool_spike_trains, bool): raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}") elif not is_time_quantity(warn_tolerance): - raise TypeError("'warn_tolerance' must be a time quantity.") + raise TypeError(f"'warn_tolerance' must be a time quantity, but got {type(warn_tolerance)}") + + def _check_input_spiketrains_durations(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], + List[np.ndarray]]) -> None: + if spiketrains and all(isinstance(st, neo.SpikeTrain) for st in spiketrains): + durations = np.array(tuple(st.duration for st in spiketrains)) + if np.max(durations) - np.min(durations) > warn_tolerance: + warnings.warn(f"Fano factor calculated for spike trains of " + f"different duration (minimum: {np.min(durations)}s, maximum " + f"{np.max(durations)}s).") + else: + warnings.warn(f"Spiketrains was of type {type(spiketrains)}, which does not support automatic duration" + f"check. The parameter 'warn_tolerance' will have no effect. Please ensure manually that" + f"all spike trains have the same duration.") - def _compute_fano(spiketrains: List[neo.SpikeTrain]) -> float: + def _compute_fano(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray]]) -> float: + # Check spike train durations + _check_input_spiketrains_durations(spiketrains) # Build array of spike counts (one per spike train) - spike_counts = np.array([len(st) for st in spiketrains]) - + spike_counts = np.array(tuple(len(st) for st in spiketrains)) # Compute FF - if all(count == 0 for count in spike_counts): + if np.all(np.array(spike_counts) == 0): # empty list of spiketrains reaches this branch, and NaN is returned return np.nan - - if all(isinstance(st, neo.SpikeTrain) for st in spiketrains): - durations = [(st.t_stop - st.t_start).simplified.item() - for st in spiketrains] - durations_min = min(durations) - durations_max = max(durations) - if durations_max - durations_min > warn_tolerance.simplified.item(): - warnings.warn("Fano factor calculated for spike trains of " - "different duration (minimum: {_min}s, maximum " - "{_max}s).".format(_min=durations_min, - _max=durations_max)) - - fano = spike_counts.var() / spike_counts.mean() - return fano + else: + return spike_counts.var()/spike_counts.mean() if isinstance(spiketrains, elephant.trials.Trials): if not pool_trials and not pool_spike_trains: diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 57ddeb96a..7db8503e6 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -381,6 +381,11 @@ def test_fanofactor_trials_pool_spike_trains_wrong_type(self): self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type", pool_spike_trains="Wrong Type") + def test_fanofactor_warn_durations_manual_check(self): + st1 = [1, 2, 3] * pq.s + st2 = [1, 2, 3] * pq.s + self.assertWarns(UserWarning, statistics.fanofactor, (st1, st2)) + class LVTestCase(unittest.TestCase): def setUp(self): From 69a4d2499d5699fca3e8516543dcf5cad30872b8 Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:17:09 +0100 Subject: [PATCH 09/12] remove pool trials parameter --- elephant/statistics.py | 18 +++--------------- elephant/test/test_statistics.py | 15 ++------------- 2 files changed, 5 insertions(+), 28 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 552be24e6..cfda0802e 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -271,7 +271,7 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None): def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray], elephant.trials.Trials], - warn_tolerance: pq.Quantity = 0.1 * pq.ms, pool_trials: bool = False, pool_spike_trains: bool = False + warn_tolerance: pq.Quantity = 0.1 * pq.ms, pool_trials: bool = False ) -> Union[float, List[float], List[List[float]]]: r""" Evaluates the empirical Fano factor F of the spike counts of @@ -303,9 +303,6 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[ pool_trials : bool, optional If True, pool spike trains across trials before computing the Fano factor. Default: False - pool_spike_trains : bool, optional - If True, pool spike trains within each trial before computing the Fano factor. - Default: False Returns ------- @@ -343,8 +340,6 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[ # Check if parameters are of the correct type if not isinstance(pool_trials, bool): raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}") - elif not isinstance(pool_spike_trains, bool): - raise TypeError(f"'pool_spike_trains' must be of type bool, but got {type(pool_spike_trains)}") elif not is_time_quantity(warn_tolerance): raise TypeError(f"'warn_tolerance' must be a time quantity, but got {type(warn_tolerance)}") @@ -374,23 +369,16 @@ def _compute_fano(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], Li return spike_counts.var()/spike_counts.mean() if isinstance(spiketrains, elephant.trials.Trials): - if not pool_trials and not pool_spike_trains: + if not pool_trials: return [[_compute_fano([spiketrain]) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(idx)] for idx in range(spiketrains.n_trials)] - elif not pool_trials and pool_spike_trains: - return [_compute_fano(spiketrains.get_spiketrains_from_trial_as_list(idx)) - for idx in range(spiketrains.n_trials)] - elif pool_trials and not pool_spike_trains: + elif pool_trials: list_of_lists_of_spiketrains = [ spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no) for trial_no in range(spiketrains.n_trials)] return [_compute_fano([list_of_lists_of_spiketrains[trial_no][st_no] for trial_no in range(len(list_of_lists_of_spiketrains))]) for st_no in range(len(list_of_lists_of_spiketrains[0]))] - elif pool_trials and pool_spike_trains: - return [_compute_fano( - [spiketrain for trial_no in range(spiketrains.n_trials) - for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)])] else: # Legacy behavior return _compute_fano(spiketrains) diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 7db8503e6..34df97542 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -354,23 +354,12 @@ def test_fanofactor_wrong_type(self): self.assertRaises(TypeError, statistics.fanofactor, [st1], warn_tolerance=1e-4) - def test_fanofactor_trials_pool_spiketrains(self): - results = statistics.fanofactor(self.test_trials, pool_spike_trains=True) - self.assertEqual(len(results), self.test_trials.n_trials) - for result in results: - self.assertEqual( - np.var(self.sp_counts) / np.mean(self.sp_counts), result) - def test_fanofactor_trials_pool_trials(self): results = statistics.fanofactor(self.test_trials, pool_trials=True) self.assertEqual(len(results), self.test_trials.n_spiketrains_trial_by_trial[0]) - def test_fanofactor_trials_pool_trials_pool_spiketrains(self): - results = statistics.fanofactor(self.test_trials, pool_trials=True, pool_spike_trains=True) - self.assertEqual(len(results), 1) - - def test_fanofactor_trials_pool_trials_false_pool_spiketrains_false(self): - results = statistics.fanofactor(self.test_trials, pool_trials=False, pool_spike_trains=False) + def test_fanofactor_trials_pool_trials_false(self): + results = statistics.fanofactor(self.test_trials, pool_trials=False) self.assertEqual(len(results), self.test_trials.n_trials) for result in results: self.assertEqual(len(result), self.test_trials.n_spiketrains_trial_by_trial[0]) From 5b1cf758ceffc310228a59ce183f0b956caa4584 Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:18:14 +0100 Subject: [PATCH 10/12] add paramter ignored for spiketrainslist to docstring --- elephant/statistics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/elephant/statistics.py b/elephant/statistics.py index cfda0802e..5204b09f5 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -302,6 +302,7 @@ def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[ Default: 0.1 ms pool_trials : bool, optional If True, pool spike trains across trials before computing the Fano factor. + Note: If `spiketrains` is a list, this parameter is ignored. Default: False Returns From 303f3636f8d2cd6b6d2533a4981f27d17c081196 Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Tue, 10 Dec 2024 10:19:18 +0100 Subject: [PATCH 11/12] remove user warning to manually check duration for numpy arrays --- elephant/statistics.py | 4 ---- elephant/test/test_statistics.py | 5 ----- 2 files changed, 9 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index 5204b09f5..c8e4df95e 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -352,10 +352,6 @@ def _check_input_spiketrains_durations(spiketrains: Union[List[neo.SpikeTrain], warnings.warn(f"Fano factor calculated for spike trains of " f"different duration (minimum: {np.min(durations)}s, maximum " f"{np.max(durations)}s).") - else: - warnings.warn(f"Spiketrains was of type {type(spiketrains)}, which does not support automatic duration" - f"check. The parameter 'warn_tolerance' will have no effect. Please ensure manually that" - f"all spike trains have the same duration.") def _compute_fano(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray]]) -> float: # Check spike train durations diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 34df97542..39b088850 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -370,11 +370,6 @@ def test_fanofactor_trials_pool_spike_trains_wrong_type(self): self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type", pool_spike_trains="Wrong Type") - def test_fanofactor_warn_durations_manual_check(self): - st1 = [1, 2, 3] * pq.s - st2 = [1, 2, 3] * pq.s - self.assertWarns(UserWarning, statistics.fanofactor, (st1, st2)) - class LVTestCase(unittest.TestCase): def setUp(self): From e9e677867d6fc7446f6e1e73bd9f3f3a5b9a4d0f Mon Sep 17 00:00:00 2001 From: Moritz Kern <92092328+Moritz-Alexander-Kern@users.noreply.github.com> Date: Tue, 10 Dec 2024 14:28:39 +0100 Subject: [PATCH 12/12] remove pool_trials arg --- elephant/statistics.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/elephant/statistics.py b/elephant/statistics.py index c8e4df95e..bdf810801 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -366,16 +366,12 @@ def _compute_fano(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], Li return spike_counts.var()/spike_counts.mean() if isinstance(spiketrains, elephant.trials.Trials): - if not pool_trials: - return [[_compute_fano([spiketrain]) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(idx)] - for idx in range(spiketrains.n_trials)] - elif pool_trials: - list_of_lists_of_spiketrains = [ - spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no) - for trial_no in range(spiketrains.n_trials)] - return [_compute_fano([list_of_lists_of_spiketrains[trial_no][st_no] - for trial_no in range(len(list_of_lists_of_spiketrains))]) - for st_no in range(len(list_of_lists_of_spiketrains[0]))] + list_of_lists_of_spiketrains = [ + spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no) + for trial_no in range(spiketrains.n_trials)] + return [_compute_fano([list_of_lists_of_spiketrains[trial_no][st_no] + for trial_no in range(len(list_of_lists_of_spiketrains))]) + for st_no in range(len(list_of_lists_of_spiketrains[0]))] else: # Legacy behavior return _compute_fano(spiketrains)