diff --git a/elephant/statistics.py b/elephant/statistics.py index 45d9cd283..191aad4b9 100644 --- a/elephant/statistics.py +++ b/elephant/statistics.py @@ -81,7 +81,7 @@ import elephant.trials from elephant.conversion import BinnedSpikeTrain from elephant.utils import deprecated_alias, check_neo_consistency, \ - is_time_quantity, round_binning_errors + is_time_quantity, round_binning_errors, is_list_spiketrains # do not import unicode_literals # (quantities rescale does not work with unicodes) @@ -613,7 +613,7 @@ def instantaneous_rate(spiketrains, sampling_period, kernel='auto', Parameters ---------- - spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials # noqa + spiketrains : neo.SpikeTrain, list of neo.SpikeTrain or elephant.trials.Trials Input spike train(s) for which the instantaneous firing rate is calculated. If a list of spike trains is supplied, the parameter pool_spike_trains determines the behavior of the function. If a Trials @@ -1031,8 +1031,7 @@ def optimal_kernel(st): sigma=str(kernel.sigma), invert=kernel.invert) - if isinstance(spiketrains, neo.core.spiketrainlist.SpikeTrainList) and ( - pool_spike_trains): + if is_list_spiketrains(spiketrains) and (pool_spike_trains): rate = np.mean(rate, axis=1) rate = neo.AnalogSignal(signal=rate, diff --git a/elephant/test/test_statistics.py b/elephant/test/test_statistics.py index 426111810..7d6c8825f 100644 --- a/elephant/test/test_statistics.py +++ b/elephant/test/test_statistics.py @@ -482,7 +482,7 @@ def test_cv2_raise_error(self): self.assertRaises(ValueError, statistics.cv2, np.array([seq, seq])) -class InstantaneousRateTest(unittest.TestCase): +class InstantaneousRateTestCase(unittest.TestCase): @classmethod def setUpClass(cls) -> None: @@ -490,7 +490,7 @@ def setUpClass(cls) -> None: Run once before tests: """ - block = _create_trials_block(n_trials=36) + block = _create_trials_block(n_trials=36, n_spiketrains=5) cls.block = block cls.trial_object = TrialsFromBlock(block, description='trials are segments') @@ -988,8 +988,44 @@ def test_instantaneous_rate_trials_pool_trials(self): pool_spike_trains=False, pool_trials=True) self.assertIsInstance(rate, neo.core.AnalogSignal) + self.assertEqual(rate.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0]) - def test_instantaneous_rate_list_pool_spike_trains(self): + def test_instantaneous_rate_trials_pool_spiketrains(self): + kernel = kernels.GaussianKernel(sigma=500 * pq.ms) + + rate = statistics.instantaneous_rate(self.trial_object, + sampling_period=0.1 * pq.ms, + kernel=kernel, + pool_spike_trains=True, + pool_trials=False) + self.assertIsInstance(rate, list) + self.assertEqual(len(rate), self.trial_object.n_trials) + self.assertEqual(rate[0].shape[1], 1) + + def test_instantaneous_rate_trials_pool_spiketrains_pool_trials(self): + kernel = kernels.GaussianKernel(sigma=500 * pq.ms) + + rate = statistics.instantaneous_rate(self.trial_object, + sampling_period=0.1 * pq.ms, + kernel=kernel, + pool_spike_trains=True, + pool_trials=True) + self.assertIsInstance(rate, neo.AnalogSignal) + self.assertEqual(rate.shape[1], 1) + + def test_instantaneous_rate_trials_pool_spiketrains_false_pool_trials_false(self): + kernel = kernels.GaussianKernel(sigma=500 * pq.ms) + + rate = statistics.instantaneous_rate(self.trial_object, + sampling_period=0.1 * pq.ms, + kernel=kernel, + pool_spike_trains=False, + pool_trials=False) + self.assertIsInstance(rate, list) + self.assertEqual(len(rate), self.trial_object.n_trials) + self.assertEqual(rate[0].shape[1], self.trial_object.n_spiketrains_trial_by_trial[0]) + + def test_instantaneous_rate_spiketrainlist_pool_spike_trains(self): kernel = kernels.GaussianKernel(sigma=500 * pq.ms) rate = statistics.instantaneous_rate( @@ -999,7 +1035,19 @@ def test_instantaneous_rate_list_pool_spike_trains(self): pool_spike_trains=True, pool_trials=False) self.assertIsInstance(rate, neo.core.AnalogSignal) - self.assertEqual(rate.magnitude.shape[1], 1) + self.assertEqual(rate.shape[1], 1) + + def test_instantaneous_rate_list_pool_spike_trains(self): + kernel = kernels.GaussianKernel(sigma=500 * pq.ms) + + rate = statistics.instantaneous_rate( + list(self.trial_object.get_spiketrains_from_trial_as_list(0)), + sampling_period=0.1 * pq.ms, + kernel=kernel, + pool_spike_trains=True, + pool_trials=False) + self.assertIsInstance(rate, neo.core.AnalogSignal) + self.assertEqual(rate.shape[1], 1) def test_instantaneous_rate_list_of_spike_trains(self): kernel = kernels.GaussianKernel(sigma=500 * pq.ms) @@ -1010,7 +1058,7 @@ def test_instantaneous_rate_list_of_spike_trains(self): pool_spike_trains=False, pool_trials=False) self.assertIsInstance(rate, neo.core.AnalogSignal) - self.assertEqual(rate.magnitude.shape[1], 2) + self.assertEqual(rate.magnitude.shape[1], self.trial_object.n_spiketrains_trial_by_trial[0]) class TimeHistogramTestCase(unittest.TestCase): diff --git a/elephant/test/test_utils.py b/elephant/test/test_utils.py index cc927d53e..709955c4b 100644 --- a/elephant/test/test_utils.py +++ b/elephant/test/test_utils.py @@ -136,5 +136,42 @@ def test_decorator_return_with_list_of_lists_input_as_kwarg(self): self.assertIsInstance(spiketrain, SpikeTrain) +class TestIsListNeoSpiketrains(unittest.TestCase): + def setUp(self): + # Set up common test spiketrains. + self.spiketrain1 = neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=4 * pq.s) + self.spiketrain2 = neo.SpikeTrain([2, 3, 4] * pq.s, t_stop=5 * pq.s) + + def test_valid_list_input(self): + valid_list = [self.spiketrain1, self.spiketrain2] + self.assertTrue(utils.is_list_spiketrains(valid_list)) + + def test_valid_tuple_input(self): + valid_tuple = (self.spiketrain1, self.spiketrain2) + self.assertTrue(utils.is_list_spiketrains(valid_tuple)) + + def test_valid_spiketrainlist_input(self): + valid_spiketrainlist = neo.core.spiketrainlist.SpikeTrainList(items=(self.spiketrain1, self.spiketrain2)) + self.assertTrue(utils.is_list_spiketrains(valid_spiketrainlist)) + + def test_non_iterable_input(self): + self.assertFalse(utils.is_list_spiketrains(42)) + + def test_non_spiketrain_objects(self): + invalid_list = [self.spiketrain1, "not a spiketrain"] + self.assertFalse(utils.is_list_spiketrains(invalid_list)) + + def test_mixed_types_input(self): + invalid_mixed = [self.spiketrain1, 42, self.spiketrain2] + self.assertFalse(utils.is_list_spiketrains(invalid_mixed)) + + def test_none_input(self): + self.assertFalse(utils.is_list_spiketrains(None)) + + def test_single_spiketrain_input(self): + single_spiketrain = neo.SpikeTrain([1, 2, 3] * pq.s, t_stop=4 * pq.s) + self.assertFalse(utils.is_list_spiketrains(single_spiketrain)) + + if __name__ == '__main__': unittest.main() diff --git a/elephant/utils.py b/elephant/utils.py index b4ddfee22..906f513c4 100644 --- a/elephant/utils.py +++ b/elephant/utils.py @@ -7,6 +7,7 @@ check_neo_consistency check_same_units round_binning_errors + is_list_spiketrains """ from __future__ import division, print_function, unicode_literals @@ -21,7 +22,8 @@ import quantities as pq from elephant.trials import Trials - +import collections.abc +import neo __all__ = [ "deprecated_alias", @@ -31,6 +33,7 @@ "check_neo_consistency", "check_same_units", "round_binning_errors", + "is_list_spiketrains", ] @@ -446,3 +449,30 @@ def wrapper(*args, **kwargs): return method(*new_args, **new_kwargs) return wrapper + + +def is_list_spiketrains(obj: object) -> bool: + """ + Check if input is an iterable containing only neo.SpikeTrain objects. + + Parameters + ---------- + obj : object + The object to check. + + Returns + ------- + bool + True if obj is an iterable containing only neo.SpikeTrain objects. + + """ + + if not isinstance(obj, collections.abc.Iterable): + # Input must be an iterable (list, tuple, etc.) + return False + + if not all(isinstance(st, neo.SpikeTrain) for st in obj): + # All elements must be neo.SpikeTrain objects + return False + + return True