diff --git a/hnn_core/cell_response.py b/hnn_core/cell_response.py index 248eca559..658de45c4 100644 --- a/hnn_core/cell_response.py +++ b/hnn_core/cell_response.py @@ -153,6 +153,24 @@ def __eq__(self, other): def spike_times(self): return self._spike_times + @property + def cell_types(self): + """Get unique cell types.""" + spike_types_data = np.concatenate(np.array(self.spike_types, dtype=object)) + return np.unique(spike_types_data).tolist() + + @property + def spike_times_by_type(self): + """Get a dictionary of spike times by cell type""" + spike_times = dict() + for cell_type in self.cell_types: + spike_times[cell_type] = list() + for trial_spike_times, trial_spike_types in zip(self.spike_times, self.spike_types): + mask = np.isin(trial_spike_types, cell_type) + trial_cell_spike_times = np.array(trial_spike_times)[mask].tolist() + spike_times[cell_type].append(trial_cell_spike_times) + return spike_times + @property def spike_gids(self): return self._spike_gids diff --git a/hnn_core/tests/test_cell_response.py b/hnn_core/tests/test_cell_response.py index 22a77f1af..b1d243109 100644 --- a/hnn_core/tests/test_cell_response.py +++ b/hnn_core/tests/test_cell_response.py @@ -24,6 +24,11 @@ def test_cell_response(tmp_path): spike_gids=spike_gids, spike_types=spike_types, times=sim_times) + + assert set(cell_response.cell_types) == set(gid_ranges.keys()) + assert cell_response.spike_times_by_type['L2_basket'] == [[7.89], []] + assert cell_response.spike_times_by_type['L5_pyramidal'] == [[], [4.2812]] + kwargs_hist = dict(alpha=0.25) fig = cell_response.plot_spikes_hist(show=False, **kwargs_hist) assert all(patch.get_alpha() == kwargs_hist['alpha']