Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] Integrate trials object with Fano factor #645

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 12 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 59 additions & 32 deletions elephant/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@
import scipy.signal
from numpy import ndarray
from scipy.special import erf
from typing import Union
from typing import List, Union

import elephant.conversion as conv
import elephant.kernels as kernels
Expand Down Expand Up @@ -270,10 +270,12 @@ def mean_firing_rate(spiketrain, t_start=None, t_stop=None, axis=None):
return rates


def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms):
def fanofactor(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray], elephant.trials.Trials],
warn_tolerance: pq.Quantity = 0.1 * pq.ms, pool_trials: bool = False
) -> Union[float, List[float], List[List[float]]]:
r"""
Evaluates the empirical Fano factor F of the spike counts of
a list of `neo.SpikeTrain` objects.
a list of `neo.SpikeTrain` objects or `elephant.trials.Trial` object.

Given the vector v containing the observed spike counts (one per
spike train) in the time window [t0, t1], F is defined as:
Expand All @@ -288,32 +290,40 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms):

Parameters
----------
spiketrains : list
spiketrains : list or elephant.trials.Trial
List of `neo.SpikeTrain` or `pq.Quantity` or `np.ndarray` or list of
spike times for which to compute the Fano factor of spike counts.
spike times for which to compute the Fano factor of spike counts, or
an `elephant.trials.Trial` object, here the behavior can be controlled with the
pool_trials and pool_spike_trains parameters.
warn_tolerance : pq.Quantity
In case of a list of input neo.SpikeTrains, if their durations vary by
more than `warn_tolerence` in their absolute values, throw a warning
more than `warn_tolerance` in their absolute values, throw a warning
(see Notes).
Default: 0.1 ms
pool_trials : bool, optional
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove parameter -> default behavior for trial object as input

If True, pool spike trains across trials before computing the Fano factor.
Note: If `spiketrains` is a list, this parameter is ignored.
Default: False
Moritz-Alexander-Kern marked this conversation as resolved.
Show resolved Hide resolved

Returns
-------
fano : float
fano : float, list of floats or list of list of floats
The Fano factor of the spike counts of the input spike trains.
Returns np.NaN if an empty list is specified, or if all spike trains
are empty.
are empty. If a `Trial` object is provided, returns a list of Fano
factors.

Raises
------
TypeError
If the input spiketrains are neo.SpikeTrain objects, but
`warn_tolerance` is not a quantity.
If the parameters `pool_trials` or `pool_spike_trains` are not of type bool.

Notes
-----
The check for the equal duration of the input spike trains is performed
only if the input is of type`neo.SpikeTrain`: if you pass a numpy array,
only if the input is of type`neo.SpikeTrain`: if you pass e.g. a numpy array,
please make sure that they all have the same duration manually.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move notes to spiketrains parameter docstring


Examples
Expand All @@ -328,29 +338,46 @@ def fanofactor(spiketrains, warn_tolerance=0.1 * pq.ms):
0.07142857142857142

"""
# Build array of spike counts (one per spike train)
spike_counts = np.array([len(st) for st in spiketrains])

# Compute FF
if all(count == 0 for count in spike_counts):
# empty list of spiketrains reaches this branch, and NaN is returned
return np.nan

if all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
if not is_time_quantity(warn_tolerance):
raise TypeError("'warn_tolerance' must be a time quantity.")
durations = [(st.t_stop - st.t_start).simplified.item()
for st in spiketrains]
durations_min = min(durations)
durations_max = max(durations)
if durations_max - durations_min > warn_tolerance.simplified.item():
warnings.warn("Fano factor calculated for spike trains of "
"different duration (minimum: {_min}s, maximum "
"{_max}s).".format(_min=durations_min,
_max=durations_max))

fano = spike_counts.var() / spike_counts.mean()
return fano
# Check if parameters are of the correct type
if not isinstance(pool_trials, bool):
raise TypeError(f"'pool_trials' must be of type bool, but got {type(pool_trials)}")
elif not is_time_quantity(warn_tolerance):
raise TypeError(f"'warn_tolerance' must be a time quantity, but got {type(warn_tolerance)}")

def _check_input_spiketrains_durations(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity],
List[np.ndarray]]) -> None:
if spiketrains and all(isinstance(st, neo.SpikeTrain) for st in spiketrains):
durations = np.array(tuple(st.duration for st in spiketrains))
if np.max(durations) - np.min(durations) > warn_tolerance:
warnings.warn(f"Fano factor calculated for spike trains of "
f"different duration (minimum: {np.min(durations)}s, maximum "
f"{np.max(durations)}s).")

def _compute_fano(spiketrains: Union[List[neo.SpikeTrain], List[pq.Quantity], List[np.ndarray]]) -> float:
# Check spike train durations
_check_input_spiketrains_durations(spiketrains)
# Build array of spike counts (one per spike train)
spike_counts = np.array(tuple(len(st) for st in spiketrains))
# Compute FF
if np.all(np.array(spike_counts) == 0):
# empty list of spiketrains reaches this branch, and NaN is returned
return np.nan
else:
return spike_counts.var()/spike_counts.mean()

if isinstance(spiketrains, elephant.trials.Trials):
if not pool_trials:
return [[_compute_fano([spiketrain]) for spiketrain in spiketrains.get_spiketrains_from_trial_as_list(idx)]
for idx in range(spiketrains.n_trials)]
elif pool_trials:
list_of_lists_of_spiketrains = [
spiketrains.get_spiketrains_from_trial_as_list(trial_id=trial_no)
for trial_no in range(spiketrains.n_trials)]
return [_compute_fano([list_of_lists_of_spiketrains[trial_no][st_no]
for trial_no in range(len(list_of_lists_of_spiketrains))])
for st_no in range(len(list_of_lists_of_spiketrains[0]))]
else: # Legacy behavior
return _compute_fano(spiketrains)


def __variation_check(v, with_nan):
Expand Down
44 changes: 31 additions & 13 deletions elephant/test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from elephant import statistics
from elephant.spike_train_generation import StationaryPoissonProcess
from elephant.test.test_trials import _create_trials_block
from elephant.trials import TrialsFromBlock
from elephant.trials import TrialsFromBlock, TrialsFromLists


class IsiTestCase(unittest.TestCase):
Expand Down Expand Up @@ -269,32 +269,34 @@ def test_mean_firing_rate_with_plain_array_and_units_start_stop_typeerror(


class FanoFactorTestCase(unittest.TestCase):
def setUp(self):
@classmethod
def setUpClass(cls):
np.random.seed(100)
num_st = 300
self.test_spiketrains = []
self.test_array = []
self.test_quantity = []
self.test_list = []
self.sp_counts = np.zeros(num_st)
cls.test_spiketrains = []
cls.test_array = []
cls.test_quantity = []
cls.test_list = []
cls.sp_counts = np.zeros(num_st)
for i in range(num_st):
r = np.random.rand(np.random.randint(20) + 1)
st = neo.core.SpikeTrain(r * pq.ms,
t_start=0.0 * pq.ms,
t_stop=20.0 * pq.ms)
self.test_spiketrains.append(st)
self.test_array.append(r)
self.test_quantity.append(r * pq.ms)
self.test_list.append(list(r))
cls.test_spiketrains.append(st)
cls.test_array.append(r)
cls.test_quantity.append(r * pq.ms)
cls.test_list.append(list(r))
# for cross-validation
self.sp_counts[i] = len(st)
cls.sp_counts[i] = len(st)

cls.test_trials = TrialsFromLists([cls.test_spiketrains, cls.test_spiketrains])

def test_fanofactor_spiketrains(self):
# Test with list of spiketrains
self.assertEqual(
np.var(self.sp_counts) / np.mean(self.sp_counts),
statistics.fanofactor(self.test_spiketrains))

# One spiketrain in list
st = self.test_spiketrains[0]
self.assertEqual(statistics.fanofactor([st]), 0.0)
Expand Down Expand Up @@ -352,6 +354,22 @@ def test_fanofactor_wrong_type(self):
self.assertRaises(TypeError, statistics.fanofactor, [st1],
warn_tolerance=1e-4)

def test_fanofactor_trials_pool_trials(self):
results = statistics.fanofactor(self.test_trials, pool_trials=True)
self.assertEqual(len(results), self.test_trials.n_spiketrains_trial_by_trial[0])

def test_fanofactor_trials_pool_trials_false(self):
results = statistics.fanofactor(self.test_trials, pool_trials=False)
self.assertEqual(len(results), self.test_trials.n_trials)
for result in results:
self.assertEqual(len(result), self.test_trials.n_spiketrains_trial_by_trial[0])

def test_fanofactor_trials_pool_spike_trains_wrong_type(self):
self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trains="Wrong Type")
self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type")
self.assertRaises(TypeError, statistics.fanofactor, self.test_trials, pool_spike_trials="Wrong Type",
pool_spike_trains="Wrong Type")


class LVTestCase(unittest.TestCase):
def setUp(self):
Expand Down
Loading