From a712ccb470405d82b104a48db2a2159202fc9c4c Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Thu, 18 May 2023 20:58:20 -0400 Subject: [PATCH 01/30] Add cell position argument ot cell plots --- hnn_core/cell.py | 10 +++++++--- hnn_core/viz.py | 22 ++++++++++++++++------ 2 files changed, 23 insertions(+), 9 deletions(-) diff --git a/hnn_core/cell.py b/hnn_core/cell.py index 8ac30a330..9b49efea9 100644 --- a/hnn_core/cell.py +++ b/hnn_core/cell.py @@ -863,7 +863,7 @@ def parconnect_from_src(self, gid_presyn, nc_dict, postsyn, return nc - def plot_morphology(self, ax=None, color=None, show=True): + def plot_morphology(self, ax=None, color=None, pos=(0, 0, 0), show=True): """Plot the cell morphology. Parameters @@ -875,8 +875,11 @@ def plot_morphology(self, ax=None, color=None, show=True): color indicated by str. If dict, colors of individual sections can be specified. Must have a key for every section in cell as defined in the `Cell.sections` attribute. - | Ex: ``{'apical_trunk': 'r', 'soma': 'b', ...}`` + | Ex: ``{'apical_trunk': 'r', 'soma': 'b', ...}`` + pos : tuple of int or float | None + Position of cell soma. Must be a tuple of 3 elements for the + (x, y, z) position of the soma in 3D space. Default: (0, 0, 0) show : bool If True, show the plot @@ -885,7 +888,8 @@ def plot_morphology(self, ax=None, color=None, show=True): axes : instance of Axes3D The matplotlib 3D axis handle. """ - return plot_cell_morphology(self, ax=ax, color=color, show=show) + return plot_cell_morphology(self, ax=ax, color=color, pos=pos, + show=show) def _update_section_end_pts_L(self, node, dpt): if self.cell_tree is None: diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 16e3e8c54..e9dddb5de 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -853,7 +853,7 @@ def _linewidth_from_data_units(ax, linewidth): return linewidth * (length / value_range) -def plot_cell_morphology(cell, ax, color=None, show=True): +def plot_cell_morphology(cell, ax, color=None, pos=(0, 0, 0), show=True): """Plot the cell morphology. Parameters @@ -871,6 +871,9 @@ def plot_cell_morphology(cell, ax, color=None, show=True): defined in the `Cell.sections` attribute. | Ex: ``{'apical_trunk': 'r', 'soma': 'b', ...}`` + pos : tuple of int or float | None + Position of cell soma. Must be a tuple of 3 elements for the + (x, y, z) position of the soma in 3D space. Default: (0, 0, 0) Returns ------- @@ -893,18 +896,25 @@ def plot_cell_morphology(cell, ax, color=None, show=True): if isinstance(color, dict): section_colors = color + _validate_type(pos, tuple, 'pos') + if isinstance(pos, tuple): + if len(pos) != 3: + raise ValueError('pos must be a tuple of 3 elements') + for pos_idx in pos: + _validate_type(pos_idx, (float, int), 'pos[idx]') + # Cell is in XZ plane - ax.set_xlim((cell.pos[1] - 250, cell.pos[1] + 150)) - ax.set_zlim((cell.pos[2] - 100, cell.pos[2] + 1200)) + # ax.set_xlim((pos[1] - 250, pos[1] + 150)) + # ax.set_zlim((pos[2] - 100, pos[2] + 1200)) for sec_name, section in cell.sections.items(): linewidth = _linewidth_from_data_units(ax, section.diam) end_pts = section.end_pts xs, ys, zs = list(), list(), list() for pt in end_pts: - dx = cell.pos[0] - cell.sections['soma'].end_pts[0][0] - dy = cell.pos[1] - cell.sections['soma'].end_pts[0][1] - dz = cell.pos[2] - cell.sections['soma'].end_pts[0][2] + dx = pos[0] - cell.sections['soma'].end_pts[0][0] + dy = pos[1] - cell.sections['soma'].end_pts[0][1] + dz = pos[2] - cell.sections['soma'].end_pts[0][2] xs.append(pt[0] + dx) ys.append(pt[1] + dz) zs.append(pt[2] + dy) From 4e09d90b122829709c50ba5c4e69ea0532db503a Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Thu, 18 May 2023 20:58:44 -0400 Subject: [PATCH 02/30] Add position tests --- hnn_core/tests/test_viz.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index b26754c5a..dab93fe44 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -84,6 +84,13 @@ def test_network_visualization(): with pytest.raises(TypeError, match="'ax' to be an instance of Axes3D, but got Axes"): plot_cells(net, ax=axes, show=False) + cell_type.plot_morphology(pos=(1.0, 2.0, 3.0)) + with pytest.raises(TypeError, match='pos must be'): + cell_type.plot_morphology(pos=123) + with pytest.raises(ValueError, match='pos must be a tuple of 3 elements'): + cell_type.plot_morphology(pos=(1, 2, 3, 4)) + with pytest.raises(TypeError, match='pos\\[idx\\] must be'): + cell_type.plot_morphology(pos=(1, '2', 3)) plt.close('all') From d55a831e0753e8c2b88e3f4e7c75235ac62c0056 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sat, 20 May 2023 17:15:12 -0400 Subject: [PATCH 03/30] WIP --- hnn_core/cell.py | 6 ++++-- hnn_core/viz.py | 8 +++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/hnn_core/cell.py b/hnn_core/cell.py index 9b49efea9..cb0d488b0 100644 --- a/hnn_core/cell.py +++ b/hnn_core/cell.py @@ -863,7 +863,9 @@ def parconnect_from_src(self, gid_presyn, nc_dict, postsyn, return nc - def plot_morphology(self, ax=None, color=None, pos=(0, 0, 0), show=True): + def plot_morphology(self, ax=None, color=None, pos=(0, 0, 0), + xlim=(-250, 150), ylim=None, zlim=(-100, 1200), + show=True): """Plot the cell morphology. Parameters @@ -889,7 +891,7 @@ def plot_morphology(self, ax=None, color=None, pos=(0, 0, 0), show=True): The matplotlib 3D axis handle. """ return plot_cell_morphology(self, ax=ax, color=color, pos=pos, - show=show) + xlim=xlim, ylim=ylim, zlim=zlim, show=show) def _update_section_end_pts_L(self, node, dpt): if self.cell_tree is None: diff --git a/hnn_core/viz.py b/hnn_core/viz.py index e9dddb5de..e1ad47f50 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -853,7 +853,8 @@ def _linewidth_from_data_units(ax, linewidth): return linewidth * (length / value_range) -def plot_cell_morphology(cell, ax, color=None, pos=(0, 0, 0), show=True): +def plot_cell_morphology(cell, ax, color=None, pos=(0, 0, 0), xlim=(-250, 150), + ylim=None, zlim=(-100, 1200), show=True): """Plot the cell morphology. Parameters @@ -904,8 +905,9 @@ def plot_cell_morphology(cell, ax, color=None, pos=(0, 0, 0), show=True): _validate_type(pos_idx, (float, int), 'pos[idx]') # Cell is in XZ plane - # ax.set_xlim((pos[1] - 250, pos[1] + 150)) - # ax.set_zlim((pos[2] - 100, pos[2] + 1200)) + ax.set_xlim((pos[0] - xlim[0], pos[0] + xlim[1])) + ax.set_ylim((pos[1] - ylim[0], pos[1] + ylim[1])) + ax.set_zlim((pos[2] - zlim[0], pos[2] + zlim[1])) for sec_name, section in cell.sections.items(): linewidth = _linewidth_from_data_units(ax, section.diam) From e7ab8ff9e9c117d3491c1e46077a0ebe040c87ab Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sun, 21 May 2023 22:27:46 -0400 Subject: [PATCH 04/30] start notebook --- examples/howto/plot_hnn_animation.py | 99 ++++++++++++++++++++++++++++ 1 file changed, 99 insertions(+) create mode 100644 examples/howto/plot_hnn_animation.py diff --git a/examples/howto/plot_hnn_animation.py b/examples/howto/plot_hnn_animation.py new file mode 100644 index 000000000..fb8d395f6 --- /dev/null +++ b/examples/howto/plot_hnn_animation.py @@ -0,0 +1,99 @@ +""" +================================ +XX. Modifying local connectivity +================================ + +This example demonstrates how to animate HNN simulations +""" + +# Author: Nick Tolley + +# sphinx_gallery_thumbnail_number = 2 + + +############################################################################### +def plot_network(net, ax, t_idx, colormap): + """ + colormap : str + The name of a matplotlib colormap. Default: 'viridis' + """ + + if ax is None: + ax = plt.axes(projection='3d') + + xlim = (-200, 3100) + ylim = (-200, 3100) + # ylim = (-3000, 3100) + + zlim = (-300, 2200) + #viridis = cm.get_cmap('viridis', 8) + + for cell_type in net.cell_types: + gid_range = net.gid_ranges[cell_type] + for gid_idx, gid in enumerate(gid_range): + print(gid, end=' ') + + cell = net.cell_types[cell_type] + # vsec = {sec_name: ((np.array(net.cell_response.vsec[0][gid][ + # sec_name]) - vmin) / (vmax - vmin)) for + # sec_name in cell.sections.keys()} + # section_colors = {sec_name: viridis(vsec[sec_name][t_idx]) for + # sec_name in cell.sections.keys()} + + section_colors = 'C0' + + pos = net.pos_dict[cell_type][gid_idx] + pos = (float(pos[0]), float(pos[2]), float(pos[1])) + # plot_cell_morphology( + # cell, ax=ax, show=False, pos=pos, + # xlim=xlim, ylim=ylim, zlim=zlim, color=section_colors) + cell.plot_morphology(ax=ax, show=False, color=section_colors, + pos=pos, xlim=xlim, ylim=ylim, zlim=zlim) + # ax.view_init(10, -100) + ax.view_init(10, -500) + + ax.set_xlim(xlim) + ax.set_ylim(ylim) + ax.set_zlim(zlim) + # ax.axis('on') + + return ax + +def get_colors(net, t_idx, colormap): + color_list = list() + for cell_type in net.cell_types: + gid_range = net.gid_ranges[cell_type] + for gid_idx, gid in enumerate(gid_range): + + cell = net.cell_types[cell_type] + vmin, vmax = -100, 50 + + for sec_name in cell.sections.keys(): + vsec = (np.array(net.cell_response.vsec[0][gid][sec_name]) - vmin) / (vmax - vmin) + color_list.append(colormap(vsec[t_idx])) + return color_list + + + +def update_colors(ax, net, t_idx, colormap): + color_list = get_colors(net, t_idx, colormap) + lines = ax.get_lines() + for line, color in zip(lines, color_list): + line.set_color(color) + ax.view_init(10, -500) + + +net = jones_2009_model() +net.set_cell_positions(inplane_distance=300) +add_erp_drives_to_jones_model(net) +dpl = simulate_dipole(net, dt=0.5, tstop=100, record_vsec='all') + + +fig = plt.figure() +ax = fig.add_subplot(projection='3d') +plot_network(net, ax=ax, t_idx=None, colormap=None) + +colormap = colormaps['viridis'] +update_colors(ax, net, t_idx=100, colormap=colormap) +ax.view_init(20, 100) +fig From d058453bee67e5c42b47ec1d481e55d641afeb82 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Mon, 22 May 2023 14:28:39 -0400 Subject: [PATCH 05/30] Start networkplot class --- hnn_core/viz.py | 90 +++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 90 insertions(+) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index e1ad47f50..c1a7684b4 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1240,3 +1240,93 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, plt_show(show) return ax.get_figure() + + +class NetworkPlot: + """Helper class to visualize full + morphology of HNN model. + + Parameters + ---------- + net : Instance of Network object + The Network object + + colormap_name : str + Name of colormap used to visualize membrane potential + + Attributes + ---------- + + """ + def __init__(self, net, ax=None, vmin=-100, vmax=50, default_color='b', + voltage_colormap='viridis', elev=10, azim=-500, + xlim=(-200, 3100), ylim=(-200, 300), zlim=(-300, 2200), + trial_idx=0): + import matplotlib.pyplot as plt + from matplotlib import colormaps + + self.net = net + + self.vmin = vmin + self.vmax = vmax + + self.default_color = default_color + self.voltage_colormap = voltage_colormap + self.colormap = colormaps[voltage_colormap] + + # Axes limits and view positions + self.xlim = xlim + self.ylim = ylim + self.zlim = zlim + self.elev = elev + self.azim = azim + + self.trial_idx = trial_idx + + # Get voltage data and corresponding colors + self.vsec_array = self.get_voltages() + self.color_array = self.colormap(self.vsec_array) + + # Create figure + if ax is None: + self.fig = plt.figure() + self.ax = self.fig.add_subplot(projection='3d') + else: + self.fig=None + self.init_network_plot() + + def get_voltages(self): + vsec_list = list() + for cell_type in self.net.cell_types: + gid_range = self.net.gid_ranges[cell_type] + for gid in gid_range: + + cell = self.net.cell_types[cell_type] + + for sec_name in cell.sections.keys(): + vsec = np.array(self.net.cell_response.vsec[self.trial_idx][gid][sec_name]) + vsec_list.append(vsec) + + vsec_array = np.vstack(vsec_list) + vsec_array = (vsec_array - self.vmin) / (self.vmax - self.vmin) + return np.vstack(vsec_list) + + def update_section_voltages(self, lines, color_list): + return + + def plot_voltage(self, t_idx): + return + + def init_network_plot(self): + for cell_type in self.net.cell_types: + gid_range = self.net.gid_ranges[cell_type] + for gid_idx, gid in enumerate(gid_range): + + cell = self.net.cell_types[cell_type] + + pos = self.net.pos_dict[cell_type][gid_idx] + pos = (float(pos[0]), float(pos[2]), float(pos[1])) + + cell.plot_morphology(ax=self.ax, show=False, color=self.default_color, + pos=pos, xlim=self.xlim, ylim=self.ylim, zlim=self.zlim) + From 3693b0da0c252f436fc5068776018f8c98c7f14b Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Mon, 22 May 2023 17:38:38 -0400 Subject: [PATCH 06/30] Functioning NetworkPlot class --- hnn_core/viz.py | 207 ++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 174 insertions(+), 33 deletions(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index c1a7684b4..bcc429a9f 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1250,38 +1250,62 @@ class NetworkPlot: ---------- net : Instance of Network object The Network object - - colormap_name : str - Name of colormap used to visualize membrane potential - - Attributes - ---------- - + ax : instance of matplotlib Axes3D | None + An axis object from matplotlib. If None, + a new figure is created. + vmin : int | float + Lower limit of colormap for plotting voltage + Default: -100 mV + vmax : int | float + Upper limit of colormap for plotting voltage + Default: 50 mV + bg_color : str + Background color of ax. Default: 'black' + voltage_colormap : str + Colormap used for plotting voltages + Default: 'viridis' + elev : int | float + Elevation 3D plot viewpoint, default: 10 + azim : int | float + Azimuth of 3D plot view point, default: 20 + xlim : tuple of int | tuple of float + x limits of plot window. Default (-200, 3100) + ylim : tuple of int | tuple of float + y limits of plot window. Default (-200, 3100) + zlim : tuple of int | tuple of float + z limits of plot window. Default (-300, 2200) + trial_idx : int + Index of simulation trial plotted. Default: 0 + time_idx : int + Index of time point plotted. Default: 0 """ - def __init__(self, net, ax=None, vmin=-100, vmax=50, default_color='b', + def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', voltage_colormap='viridis', elev=10, azim=-500, - xlim=(-200, 3100), ylim=(-200, 300), zlim=(-300, 2200), - trial_idx=0): + xlim=(-200, 3100), ylim=(-200, 3100), zlim=(-300, 2200), + trial_idx=0, time_idx=0): import matplotlib.pyplot as plt from matplotlib import colormaps - self.net = net + self.times = net.cell_response.times + + self._vmin = vmin + self._vmax = vmax - self.vmin = vmin - self.vmax = vmax + self._bg_color = bg_color + self._voltage_colormap = voltage_colormap - self.default_color = default_color - self.voltage_colormap = voltage_colormap + self.colormaps = colormaps # Saved for voltage_colormap update method self.colormap = colormaps[voltage_colormap] # Axes limits and view positions - self.xlim = xlim - self.ylim = ylim - self.zlim = zlim - self.elev = elev - self.azim = azim + self._xlim = xlim + self._ylim = ylim + self._zlim = zlim + self._elev = elev + self._azim = azim - self.trial_idx = trial_idx + self._trial_idx = trial_idx + self._time_idx = time_idx # Get voltage data and corresponding colors self.vsec_array = self.get_voltages() @@ -1291,9 +1315,11 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, default_color='b', if ax is None: self.fig = plt.figure() self.ax = self.fig.add_subplot(projection='3d') + self.ax.set_facecolor(self._bg_color) else: - self.fig=None + self.fig = None self.init_network_plot() + self._update_axes() def get_voltages(self): vsec_list = list() @@ -1304,19 +1330,19 @@ def get_voltages(self): cell = self.net.cell_types[cell_type] for sec_name in cell.sections.keys(): - vsec = np.array(self.net.cell_response.vsec[self.trial_idx][gid][sec_name]) + vsec = np.array(self.net.cell_response.vsec[ + self.trial_idx][gid][sec_name]) vsec_list.append(vsec) - + vsec_array = np.vstack(vsec_list) vsec_array = (vsec_array - self.vmin) / (self.vmax - self.vmin) return np.vstack(vsec_list) - - def update_section_voltages(self, lines, color_list): - return - - def plot_voltage(self, t_idx): - return - + + def update_section_voltages(self, t_idx): + color_list = self.color_array[:, t_idx] + for line, color in zip(self.ax.lines, color_list): + line.set_color(color) + def init_network_plot(self): for cell_type in self.net.cell_types: gid_range = self.net.gid_ranges[cell_type] @@ -1327,6 +1353,121 @@ def init_network_plot(self): pos = self.net.pos_dict[cell_type][gid_idx] pos = (float(pos[0]), float(pos[2]), float(pos[1])) - cell.plot_morphology(ax=self.ax, show=False, color=self.default_color, - pos=pos, xlim=self.xlim, ylim=self.ylim, zlim=self.zlim) + cell.plot_morphology(ax=self.ax, show=False, + pos=pos, xlim=self.xlim, + ylim=self.ylim, zlim=self.zlim) + + def _update_axes(self): + self.ax.set_xlim(self._xlim) + self.ax.set_ylim(self._ylim) + self.ax.set_zlim(self._zlim) + + self.ax.view_init(self._elev, self._azim) + + # Axis limits + @property + def xlim(self): + return self._xlim + + @xlim.setter + def xlim(self, xlim): + self._xlim = xlim + self.ax.set_xlim(self._xlim) + + @property + def ylim(self): + return self._ylim + + @ylim.setter + def ylim(self, ylim): + self._ylim = ylim + self.ax.set_ylim(self._ylim) + + @property + def zlim(self): + return self._zlim + + @zlim.setter + def zlim(self, zlim): + self._zlim = zlim + self.ax.set_zlim(self._zlim) + + # Eleevation and azimuth of 3D viewpoint + @property + def elev(self): + return self._elev + + @elev.setter + def elev(self, elev): + self._elev = elev + self.ax.view_init(self._elev, self._azim) + + @property + def azim(self): + return self._azim + + @azim.setter + def azim(self, azim): + self._azim = azim + self.ax.view_init(self._elev, self._azim) + + # Minimum and maximum voltages + @property + def vmin(self): + return self._vmin + + @vmin.setter + def vmin(self, vmin): + self._vmin = vmin + self.vsec_array = self.get_voltages() + self.color_array = self.colormap(self.vsec_array) + + @property + def vmax(self): + return self._vmax + + @vmax.setter + def vmax(self, vmax): + self._vmax = vmax + self.vsec_array = self.get_voltages() + self.color_array = self.colormap(self.vsec_array) + # Time and trial indices + @property + def trial_idx(self): + return self._trial_idx + + @trial_idx.setter + def trial_idx(self, trial_idx): + self._trial_idx = trial_idx + self.vsec_array = self.get_voltages() + self.color_array = self.colormap(self.vsec_array) + + @property + def time_idx(self): + return self._time_idx + + @time_idx.setter + def time_idx(self, time_idx): + self._time_idx = time_idx + self.update_section_voltages(self._time_idx) + + # Background color and voltage colormaps + @property + def bg_color(self): + return self._bg_color + + @bg_color.setter + def bg_color(self, bg_color): + self._bg_color = bg_color + self.ax.set_facecolor(self._bg_color) + + @property + def voltage_colormap(self): + return self._voltage_colormap + + @voltage_colormap.setter + def voltage_colormap(self, voltage_colormap): + self._voltage_colormap = voltage_colormap + self.colormap = self.colormaps[self._voltage_colormap] + self.color_array = self.colormap(self.vsec_array) From c88bcbb8f57669d04f05b0d028603dad0bf96d65 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Mon, 22 May 2023 17:56:21 -0400 Subject: [PATCH 07/30] Update demo code --- examples/howto/plot_hnn_animation.py | 98 +++++----------------------- 1 file changed, 18 insertions(+), 80 deletions(-) diff --git a/examples/howto/plot_hnn_animation.py b/examples/howto/plot_hnn_animation.py index fb8d395f6..1b13f0402 100644 --- a/examples/howto/plot_hnn_animation.py +++ b/examples/howto/plot_hnn_animation.py @@ -8,92 +8,30 @@ # Author: Nick Tolley -# sphinx_gallery_thumbnail_number = 2 - ############################################################################### -def plot_network(net, ax, t_idx, colormap): - """ - colormap : str - The name of a matplotlib colormap. Default: 'viridis' - """ - - if ax is None: - ax = plt.axes(projection='3d') - - xlim = (-200, 3100) - ylim = (-200, 3100) - # ylim = (-3000, 3100) - - zlim = (-300, 2200) - #viridis = cm.get_cmap('viridis', 8) - - for cell_type in net.cell_types: - gid_range = net.gid_ranges[cell_type] - for gid_idx, gid in enumerate(gid_range): - print(gid, end=' ') - - cell = net.cell_types[cell_type] - # vsec = {sec_name: ((np.array(net.cell_response.vsec[0][gid][ - # sec_name]) - vmin) / (vmax - vmin)) for - # sec_name in cell.sections.keys()} - # section_colors = {sec_name: viridis(vsec[sec_name][t_idx]) for - # sec_name in cell.sections.keys()} - - section_colors = 'C0' - - pos = net.pos_dict[cell_type][gid_idx] - pos = (float(pos[0]), float(pos[2]), float(pos[1])) - # plot_cell_morphology( - # cell, ax=ax, show=False, pos=pos, - # xlim=xlim, ylim=ylim, zlim=zlim, color=section_colors) - cell.plot_morphology(ax=ax, show=False, color=section_colors, - pos=pos, xlim=xlim, ylim=ylim, zlim=zlim) - # ax.view_init(10, -100) - ax.view_init(10, -500) - - ax.set_xlim(xlim) - ax.set_ylim(ylim) - ax.set_zlim(zlim) - # ax.axis('on') - - return ax - -def get_colors(net, t_idx, colormap): - color_list = list() - for cell_type in net.cell_types: - gid_range = net.gid_ranges[cell_type] - for gid_idx, gid in enumerate(gid_range): - - cell = net.cell_types[cell_type] - vmin, vmax = -100, 50 - - for sec_name in cell.sections.keys(): - vsec = (np.array(net.cell_response.vsec[0][gid][sec_name]) - vmin) / (vmax - vmin) - color_list.append(colormap(vsec[t_idx])) - return color_list - - - -def update_colors(ax, net, t_idx, colormap): - color_list = get_colors(net, t_idx, colormap) - lines = ax.get_lines() - for line, color in zip(lines, color_list): - line.set_color(color) - ax.view_init(10, -500) - +from hnn_core import jones_2009_model, simulate_dipole +from hnn_core.network_models import add_erp_drives_to_jones_model +from hnn_core.viz import NetworkPlot net = jones_2009_model() net.set_cell_positions(inplane_distance=300) add_erp_drives_to_jones_model(net) -dpl = simulate_dipole(net, dt=0.5, tstop=100, record_vsec='all') +dpl = simulate_dipole(net, dt=0.5, tstop=170, record_vsec='all') + +net_plot = NetworkPlot(net) + +############################################################################### +from ipywidgets import interact, IntSlider +def update_plot(t_idx, elev, azim): + net_plot.update_section_voltages(t_idx) + net_plot.elev = elev + net_plot.azim = azim + return net_plot.fig -fig = plt.figure() -ax = fig.add_subplot(projection='3d') -plot_network(net, ax=ax, t_idx=None, colormap=None) +time_slider = IntSlider(min=0, max=len(net_plot.times), value=1, continuous_update=False) +elev_slider = IntSlider(min=-100, max=100, value=10, continuous_update=False) +azim_slider = IntSlider(min=-100, max=100, value=-100, continuous_update=False) -colormap = colormaps['viridis'] -update_colors(ax, net, t_idx=100, colormap=colormap) -ax.view_init(20, 100) -fig +interact(update_plot, t_idx=time_slider, elev=elev_slider, azim=azim_slider) From 7e37564eb22e5e38b3991fb483c34b4bf1327273 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Mon, 22 May 2023 18:16:50 -0400 Subject: [PATCH 08/30] Fix cell plot test --- hnn_core/cell.py | 2 +- hnn_core/viz.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/hnn_core/cell.py b/hnn_core/cell.py index cb0d488b0..3c93682d4 100644 --- a/hnn_core/cell.py +++ b/hnn_core/cell.py @@ -864,7 +864,7 @@ def parconnect_from_src(self, gid_presyn, nc_dict, postsyn, return nc def plot_morphology(self, ax=None, color=None, pos=(0, 0, 0), - xlim=(-250, 150), ylim=None, zlim=(-100, 1200), + xlim=(-250, 150), ylim=(-100, 100), zlim=(-100, 1200), show=True): """Plot the cell morphology. diff --git a/hnn_core/viz.py b/hnn_core/viz.py index bcc429a9f..e6f10d54c 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -853,8 +853,9 @@ def _linewidth_from_data_units(ax, linewidth): return linewidth * (length / value_range) -def plot_cell_morphology(cell, ax, color=None, pos=(0, 0, 0), xlim=(-250, 150), - ylim=None, zlim=(-100, 1200), show=True): +def plot_cell_morphology( + cell, ax, color=None, pos=(0, 0, 0), xlim=(-250, 150), + ylim=(-100, 100), zlim=(-100, 1200), show=True): """Plot the cell morphology. Parameters @@ -905,9 +906,9 @@ def plot_cell_morphology(cell, ax, color=None, pos=(0, 0, 0), xlim=(-250, 150), _validate_type(pos_idx, (float, int), 'pos[idx]') # Cell is in XZ plane - ax.set_xlim((pos[0] - xlim[0], pos[0] + xlim[1])) - ax.set_ylim((pos[1] - ylim[0], pos[1] + ylim[1])) - ax.set_zlim((pos[2] - zlim[0], pos[2] + zlim[1])) + ax.set_xlim((pos[0] + xlim[0], pos[0] + xlim[1])) + ax.set_zlim((pos[1] + zlim[0], pos[1] + zlim[1])) + ax.set_ylim((pos[2] + ylim[0], pos[2] + ylim[1])) for sec_name, section in cell.sections.items(): linewidth = _linewidth_from_data_units(ax, section.diam) From 983929ce2484a13b50c4608bd0073527ab2b93de Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Tue, 23 May 2023 18:50:15 -0400 Subject: [PATCH 09/30] First pass at export function, fix vmin and vmax --- hnn_core/viz.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index e6f10d54c..fe408ad0e 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1337,7 +1337,7 @@ def get_voltages(self): vsec_array = np.vstack(vsec_list) vsec_array = (vsec_array - self.vmin) / (self.vmax - self.vmin) - return np.vstack(vsec_list) + return vsec_array def update_section_voltages(self, t_idx): color_list = self.color_array[:, t_idx] @@ -1365,6 +1365,15 @@ def _update_axes(self): self.ax.view_init(self._elev, self._azim) + def export_movie(self, fname, dpi=300): + import matplotlib.animation as animation + ani = animation.FuncAnimation( + self.fig, self.set_time_idx, len(self.times) - 1, interval=30) + + writer = animation.writers['ffmpeg'](fps=30) + ani.save(fname, writer=writer, dpi=dpi) + return ani + # Axis limits @property def xlim(self): @@ -1453,6 +1462,10 @@ def time_idx(self, time_idx): self._time_idx = time_idx self.update_section_voltages(self._time_idx) + # Necessary for making animations + def set_time_idx(self, time_idx): + self.time_idx = time_idx + # Background color and voltage colormaps @property def bg_color(self): From 5821004489601d1211acd955d9f39628aea59dfe Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Tue, 4 Jul 2023 19:29:11 -0400 Subject: [PATCH 10/30] Better export func --- hnn_core/viz.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index fe408ad0e..d1ed5f728 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1365,12 +1365,18 @@ def _update_axes(self): self.ax.view_init(self._elev, self._azim) - def export_movie(self, fname, dpi=300): + def export_movie(self, fname, fps=30, dpi=300, decim=10, + interval=30, frame_start=0, frame_stop=None, + writer='ffmpeg'): import matplotlib.animation as animation + if frame_stop is None: + frame_stop = len(self.times) - 1 + + frames = np.arange(frame_start, frame_stop, decim) ani = animation.FuncAnimation( - self.fig, self.set_time_idx, len(self.times) - 1, interval=30) + self.fig, self.set_time_idx, frames, interval=interval) - writer = animation.writers['ffmpeg'](fps=30) + writer = animation.writers[writer](fps=fps) ani.save(fname, writer=writer, dpi=dpi) return ani From 70e8881fec138f22d5f3760678fec24a45db9e46 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Tue, 4 Jul 2023 20:02:42 -0400 Subject: [PATCH 11/30] Type checks for input args --- hnn_core/tests/test_viz.py | 47 +++++++++++++++++++++++++++++++++++--- hnn_core/viz.py | 23 ++++++++++++++++++- 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index dab93fe44..1fe6b8da1 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -9,8 +9,9 @@ import hnn_core from hnn_core import read_params, jones_2009_model -from hnn_core.viz import plot_cells, plot_dipole, plot_psd, plot_tfr_morlet -from hnn_core.viz import plot_connectivity_matrix, plot_cell_connectivity +from hnn_core.viz import (plot_cells, plot_dipole, plot_psd, plot_tfr_morlet, + plot_connectivity_matrix, plot_cell_connectivity, + NetworkPlotter) from hnn_core.dipole import simulate_dipole matplotlib.use('agg') @@ -132,7 +133,7 @@ def test_dipole_visualization(): weights_ampa=weights_ampa, synaptic_delays=syn_delays, event_seed=14) - dpls = simulate_dipole(net, tstop=100., n_trials=2) + dpls = simulate_dipole(net, tstop=100., n_trials=2, record_vsec='all') fig = dpls[0].plot() # plot the first dipole alone axes = fig.get_axes()[0] dpls[0].copy().smooth(window_len=10).plot(ax=axes) # add smoothed versions @@ -220,4 +221,44 @@ def test_dipole_visualization(): with pytest.raises(ValueError, match="'beta_dist' must be"): net.cell_response.plot_spikes_hist(color={'beta_prox': 'r'}) + # test NetworkPlotter class + with pytest.raises(TypeError, match='xlim must be'): + _ = NetworkPlotter(net, xlim='blah') + with pytest.raises(TypeError, match='ylim must be'): + _ = NetworkPlotter(net, ylim='blah') + with pytest.raises(TypeError, match='zlim must be'): + _ = NetworkPlotter(net, zlim='blah') + with pytest.raises(TypeError, match='elev must be'): + _ = NetworkPlotter(net, elev='blah') + with pytest.raises(TypeError, match='azim must be'): + _ = NetworkPlotter(net, azim='blah') + with pytest.raises(TypeError, match='vmin must be'): + _ = NetworkPlotter(net, vmin='blah') + with pytest.raises(TypeError, match='vmax must be'): + _ = NetworkPlotter(net, vmax='blah') + with pytest.raises(TypeError, match='trial_idx must be'): + _ = NetworkPlotter(net, trial_idx=1.0) + with pytest.raises(TypeError, match='time_idx must be'): + _ = NetworkPlotter(net, time_idx=1.0) + + net_plot = NetworkPlotter(net) + with pytest.raises(TypeError, match='xlim must be'): + net_plot.xlim = 'blah' + with pytest.raises(TypeError, match='ylim must be'): + net_plot.ylim = 'blah' + with pytest.raises(TypeError, match='zlim must be'): + net_plot.zlim = 'blah' + with pytest.raises(TypeError, match='elev must be'): + net_plot.elev = 'blah' + with pytest.raises(TypeError, match='azim must be'): + net_plot.azim = 'blah' + with pytest.raises(TypeError, match='vmin must be'): + net_plot.vmin = 'blah' + with pytest.raises(TypeError, match='vmax must be'): + net_plot.vmax = 'blah' + with pytest.raises(TypeError, match='trial_idx must be'): + net_plot.trial_idx = 1.0 + with pytest.raises(TypeError, match='time_idx must be'): + net_plot.time_idx = 1.0 + plt.close('all') diff --git a/hnn_core/viz.py b/hnn_core/viz.py index d1ed5f728..ab769e667 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1243,7 +1243,7 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, return ax.get_figure() -class NetworkPlot: +class NetworkPlotter: """Helper class to visualize full morphology of HNN model. @@ -1289,6 +1289,8 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', self.net = net self.times = net.cell_response.times + _validate_type(vmin, (int, float), 'vmin') + _validate_type(vmax, (int, float), 'vmax') self._vmin = vmin self._vmax = vmax @@ -1299,12 +1301,22 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', self.colormap = colormaps[voltage_colormap] # Axes limits and view positions + _validate_type(xlim, tuple, 'xlim') + _validate_type(ylim, tuple, 'ylim') + _validate_type(zlim, tuple, 'zlim') + _validate_type(elev, (int, float), 'elev') + _validate_type(azim, (int, float), 'azim') + self._xlim = xlim self._ylim = ylim self._zlim = zlim self._elev = elev self._azim = azim + # Trial and time indices + _validate_type(trial_idx, int, 'trial_idx') + _validate_type(time_idx, int, 'time_idx') + self._trial_idx = trial_idx self._time_idx = time_idx @@ -1387,6 +1399,7 @@ def xlim(self): @xlim.setter def xlim(self, xlim): + _validate_type(xlim, tuple, 'xlim') self._xlim = xlim self.ax.set_xlim(self._xlim) @@ -1396,6 +1409,7 @@ def ylim(self): @ylim.setter def ylim(self, ylim): + _validate_type(ylim, tuple, 'ylim') self._ylim = ylim self.ax.set_ylim(self._ylim) @@ -1405,6 +1419,7 @@ def zlim(self): @zlim.setter def zlim(self, zlim): + _validate_type(zlim, tuple, 'zlim') self._zlim = zlim self.ax.set_zlim(self._zlim) @@ -1415,6 +1430,7 @@ def elev(self): @elev.setter def elev(self, elev): + _validate_type(elev, (int, float), 'elev') self._elev = elev self.ax.view_init(self._elev, self._azim) @@ -1424,6 +1440,7 @@ def azim(self): @azim.setter def azim(self, azim): + _validate_type(azim, (int, float), 'azim') self._azim = azim self.ax.view_init(self._elev, self._azim) @@ -1434,6 +1451,7 @@ def vmin(self): @vmin.setter def vmin(self, vmin): + _validate_type(vmin, (int, float), 'vmin') self._vmin = vmin self.vsec_array = self.get_voltages() self.color_array = self.colormap(self.vsec_array) @@ -1444,6 +1462,7 @@ def vmax(self): @vmax.setter def vmax(self, vmax): + _validate_type(vmax, (int, float), 'vmax') self._vmax = vmax self.vsec_array = self.get_voltages() self.color_array = self.colormap(self.vsec_array) @@ -1455,6 +1474,7 @@ def trial_idx(self): @trial_idx.setter def trial_idx(self, trial_idx): + _validate_type(trial_idx, int, 'trial_idx') self._trial_idx = trial_idx self.vsec_array = self.get_voltages() self.color_array = self.colormap(self.vsec_array) @@ -1465,6 +1485,7 @@ def time_idx(self): @time_idx.setter def time_idx(self, time_idx): + _validate_type(time_idx, int, 'time_idx') self._time_idx = time_idx self.update_section_voltages(self._time_idx) From 6c21952f24b29a6b1a2bed8951f7a6826cd59999 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Tue, 4 Jul 2023 20:09:29 -0400 Subject: [PATCH 12/30] Better docs --- hnn_core/viz.py | 39 +++++++++++++++++++++++++++++---------- 1 file changed, 29 insertions(+), 10 deletions(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index ab769e667..24a4737a3 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1321,7 +1321,7 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', self._time_idx = time_idx # Get voltage data and corresponding colors - self.vsec_array = self.get_voltages() + self.vsec_array = self._get_voltages() self.color_array = self.colormap(self.vsec_array) # Create figure @@ -1331,10 +1331,10 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', self.ax.set_facecolor(self._bg_color) else: self.fig = None - self.init_network_plot() + self._init_network_plot() self._update_axes() - def get_voltages(self): + def _get_voltages(self): vsec_list = list() for cell_type in self.net.cell_types: gid_range = self.net.gid_ranges[cell_type] @@ -1356,7 +1356,7 @@ def update_section_voltages(self, t_idx): for line, color in zip(self.ax.lines, color_list): line.set_color(color) - def init_network_plot(self): + def _init_network_plot(self): for cell_type in self.net.cell_types: gid_range = self.net.gid_ranges[cell_type] for gid_idx, gid in enumerate(gid_range): @@ -1380,13 +1380,32 @@ def _update_axes(self): def export_movie(self, fname, fps=30, dpi=300, decim=10, interval=30, frame_start=0, frame_stop=None, writer='ffmpeg'): + """Export movie of network activity + fname : str + Filename of exported movie + fps : int + Frames per second, default: 30 + dpi : int + Dots per inch, default: 300 + decim : int + Decimation factor for frames, default: 10 + interval : int + Delay between frames, default: 30 + frame_start : int + Index of first frame, default: 0 + frame_stop : int | None + Index of last frame, default: None + If None, entire simulation is animated + writer : str + Movie writer, default: 'ffmpeg' + """ import matplotlib.animation as animation if frame_stop is None: frame_stop = len(self.times) - 1 frames = np.arange(frame_start, frame_stop, decim) ani = animation.FuncAnimation( - self.fig, self.set_time_idx, frames, interval=interval) + self.fig, self._set_time_idx, frames, interval=interval) writer = animation.writers[writer](fps=fps) ani.save(fname, writer=writer, dpi=dpi) @@ -1453,7 +1472,7 @@ def vmin(self): def vmin(self, vmin): _validate_type(vmin, (int, float), 'vmin') self._vmin = vmin - self.vsec_array = self.get_voltages() + self.vsec_array = self._get_voltages() self.color_array = self.colormap(self.vsec_array) @property @@ -1464,7 +1483,7 @@ def vmax(self): def vmax(self, vmax): _validate_type(vmax, (int, float), 'vmax') self._vmax = vmax - self.vsec_array = self.get_voltages() + self.vsec_array = self._get_voltages() self.color_array = self.colormap(self.vsec_array) # Time and trial indices @@ -1476,7 +1495,7 @@ def trial_idx(self): def trial_idx(self, trial_idx): _validate_type(trial_idx, int, 'trial_idx') self._trial_idx = trial_idx - self.vsec_array = self.get_voltages() + self.vsec_array = self._get_voltages() self.color_array = self.colormap(self.vsec_array) @property @@ -1489,8 +1508,8 @@ def time_idx(self, time_idx): self._time_idx = time_idx self.update_section_voltages(self._time_idx) - # Necessary for making animations - def set_time_idx(self, time_idx): + # Callable update function for making animations + def _set_time_idx(self, time_idx): self.time_idx = time_idx # Background color and voltage colormaps From c079edd74f2924e62557ef4c8bde441eb9d53b27 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Fri, 22 Sep 2023 15:51:13 -0400 Subject: [PATCH 13/30] Make time_idx accept np.int --- hnn_core/viz.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 24a4737a3..d30b2bd53 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1504,7 +1504,7 @@ def time_idx(self): @time_idx.setter def time_idx(self, time_idx): - _validate_type(time_idx, int, 'time_idx') + _validate_type(time_idx, (int, np.integer), 'time_idx') self._time_idx = time_idx self.update_section_voltages(self._time_idx) From 2a5fedd578f7e51e269eda510363961a3993f20f Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Fri, 22 Sep 2023 16:14:52 -0400 Subject: [PATCH 14/30] Add self.ax --- hnn_core/viz.py | 1 + 1 file changed, 1 insertion(+) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index d30b2bd53..2cc94eea8 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1330,6 +1330,7 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', self.ax = self.fig.add_subplot(projection='3d') self.ax.set_facecolor(self._bg_color) else: + self.ax = ax self.fig = None self._init_network_plot() self._update_axes() From 326af602702f09612dead697cdc7b7ac3fa748dd Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Sun, 24 Sep 2023 19:59:55 -0400 Subject: [PATCH 15/30] Add more type checks and simulation conditions --- hnn_core/tests/test_viz.py | 63 ++++++++++++++++++++++++++++++++++++++ hnn_core/viz.py | 34 ++++++++++++++++---- 2 files changed, 91 insertions(+), 6 deletions(-) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 1fe6b8da1..0789747d5 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -241,7 +241,38 @@ def test_dipole_visualization(): with pytest.raises(TypeError, match='time_idx must be'): _ = NetworkPlotter(net, time_idx=1.0) + net = jones_2009_model(params) net_plot = NetworkPlotter(net) + + assert net_plot.vsec_array.shape == (159, 1) + assert net_plot.color_array.shape == (159, 1, 4) + assert net_plot._vsec_recorded is False + + # Errors if vsec isn't recorded + with pytest.raises(RuntimeError, match='Network must be simulated'): + net_plot.export_movie('demo.gif', dpi=200) + + # Errors if vsec isn't recorded with record_vsec='all' + _ = simulate_dipole(net, dt=0.5, tstop=10, record_vsec='soma') + net_plot = NetworkPlotter(net) + + assert net_plot.vsec_array.shape == (159, 1) + assert net_plot.color_array.shape == (159, 1, 4) + assert net_plot._vsec_recorded is False + + with pytest.raises(RuntimeError, match='Network must be simulated'): + net_plot.export_movie('demo.gif', dpi=200) + + # Simulate with record_vsec='all' to test voltage plotting + net = jones_2009_model(params) + _ = simulate_dipole(net, dt=0.5, tstop=10, record_vsec='all') + net_plot = NetworkPlotter(net) + + assert net_plot.vsec_array.shape == (159, 21) + assert net_plot.color_array.shape == (159, 21, 4) + assert net_plot._vsec_recorded is True + + # Type check errors with pytest.raises(TypeError, match='xlim must be'): net_plot.xlim = 'blah' with pytest.raises(TypeError, match='ylim must be'): @@ -261,4 +292,36 @@ def test_dipole_visualization(): with pytest.raises(TypeError, match='time_idx must be'): net_plot.time_idx = 1.0 + # Check that the setters work + net_plot.xlim = (-100, 100) + net_plot.ylim = (-100, 100) + net_plot.zlim = (-100, 100) + net_plot.elev = 10 + net_plot.azim = 10 + net_plot.vmin = 0 + net_plot.vmax = 100 + net_plot.trial_idx = 0 + net_plot.time_idx = 5 + net_plot.bgcolor = 'white' + net_plot.voltage_colormap = 'jet' + + # Check that the getters work + assert net_plot.xlim == (-100, 100) + assert net_plot.ylim == (-100, 100) + assert net_plot.zlim == (-100, 100) + assert net_plot.elev == 10 + assert net_plot.azim == 10 + assert net_plot.vmin == 0 + assert net_plot.vmax == 100 + assert net_plot.trial_idx == 0 + assert net_plot.time_idx == 5 + + assert net_plot.bgcolor == 'white' + assert net_plot.fig.get_facecolor() == (1.0, 1.0, 1.0, 1.0) + + assert net_plot.voltage_colormap == 'jet' + + # Test animation export and voltage plotting + net_plot.export_movie('demo.gif', dpi=200, decim=100, writer='pillow') + plt.close('all') diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 2cc94eea8..5ef4a4929 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1287,7 +1287,20 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', import matplotlib.pyplot as plt from matplotlib import colormaps self.net = net - self.times = net.cell_response.times + + # Check if network simulated + if net.cell_response is not None: + self.times = net.cell_response.times + + # Check if voltage recorded + if net._params['record_vsec'] == 'all': + self._vsec_recorded = True + else: + self._vsec_recorded = False + else: + self._is_simulated = False + self._vsec_recorded = False + self.times = None _validate_type(vmin, (int, float), 'vmin') _validate_type(vmax, (int, float), 'vmax') @@ -1340,19 +1353,24 @@ def _get_voltages(self): for cell_type in self.net.cell_types: gid_range = self.net.gid_ranges[cell_type] for gid in gid_range: - cell = self.net.cell_types[cell_type] - for sec_name in cell.sections.keys(): - vsec = np.array(self.net.cell_response.vsec[ - self.trial_idx][gid][sec_name]) - vsec_list.append(vsec) + if self._vsec_recorded is True: + vsec = np.array(self.net.cell_response.vsec[ + self.trial_idx][gid][sec_name]) + vsec_list.append(vsec) + else: # Populate with zeros if no voltage recording + vsec_list.append([0.0]) vsec_array = np.vstack(vsec_list) vsec_array = (vsec_array - self.vmin) / (self.vmax - self.vmin) return vsec_array def update_section_voltages(self, t_idx): + if not self._vsec_recorded: + raise RuntimeError("Network must be simulated with" + "`simulate_dipole(record_vsec='all')` before" + "plotting voltages.") color_list = self.color_array[:, t_idx] for line, color in zip(self.ax.lines, color_list): line.set_color(color) @@ -1401,6 +1419,10 @@ def export_movie(self, fname, fps=30, dpi=300, decim=10, Movie writer, default: 'ffmpeg' """ import matplotlib.animation as animation + + if not self._vsec_recorded: + raise RuntimeError('Network must be simulated before' + 'plotting voltages.') if frame_stop is None: frame_stop = len(self.times) - 1 From 3e37322eb07d6fa04136b44d3d4694eb5dd3aba2 Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Mon, 25 Sep 2023 13:07:43 -0400 Subject: [PATCH 16/30] Make update_voltages private --- hnn_core/tests/test_viz.py | 7 ++++--- hnn_core/viz.py | 15 +++++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 0789747d5..d6aa76d9c 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -265,7 +265,8 @@ def test_dipole_visualization(): # Simulate with record_vsec='all' to test voltage plotting net = jones_2009_model(params) - _ = simulate_dipole(net, dt=0.5, tstop=10, record_vsec='all') + _ = simulate_dipole(net, dt=0.5, tstop=10, n_trials=2, + record_vsec='all') net_plot = NetworkPlotter(net) assert net_plot.vsec_array.shape == (159, 21) @@ -300,7 +301,7 @@ def test_dipole_visualization(): net_plot.azim = 10 net_plot.vmin = 0 net_plot.vmax = 100 - net_plot.trial_idx = 0 + net_plot.trial_idx = 1 net_plot.time_idx = 5 net_plot.bgcolor = 'white' net_plot.voltage_colormap = 'jet' @@ -313,7 +314,7 @@ def test_dipole_visualization(): assert net_plot.azim == 10 assert net_plot.vmin == 0 assert net_plot.vmax == 100 - assert net_plot.trial_idx == 0 + assert net_plot.trial_idx == 1 assert net_plot.time_idx == 5 assert net_plot.bgcolor == 'white' diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 5ef4a4929..22b44b275 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1366,7 +1366,7 @@ def _get_voltages(self): vsec_array = (vsec_array - self.vmin) / (self.vmax - self.vmin) return vsec_array - def update_section_voltages(self, t_idx): + def _update_section_voltages(self, t_idx): if not self._vsec_recorded: raise RuntimeError("Network must be simulated with" "`simulate_dipole(record_vsec='all')` before" @@ -1398,7 +1398,7 @@ def _update_axes(self): def export_movie(self, fname, fps=30, dpi=300, decim=10, interval=30, frame_start=0, frame_stop=None, - writer='ffmpeg'): + writer='pillow'): """Export movie of network activity fname : str Filename of exported movie @@ -1416,13 +1416,16 @@ def export_movie(self, fname, fps=30, dpi=300, decim=10, Index of last frame, default: None If None, entire simulation is animated writer : str - Movie writer, default: 'ffmpeg' + Movie writer, default: 'pillow'. + Alternative movie writers can be found at + https://matplotlib.org/stable/api/animation_api.html """ import matplotlib.animation as animation if not self._vsec_recorded: - raise RuntimeError('Network must be simulated before' - 'plotting voltages.') + raise RuntimeError("Network must be simulated with" + "`simulate_dipole(record_vsec='all')` before" + "plotting voltages.") if frame_stop is None: frame_stop = len(self.times) - 1 @@ -1529,7 +1532,7 @@ def time_idx(self): def time_idx(self, time_idx): _validate_type(time_idx, (int, np.integer), 'time_idx') self._time_idx = time_idx - self.update_section_voltages(self._time_idx) + self._update_section_voltages(self._time_idx) # Callable update function for making animations def _set_time_idx(self, time_idx): From 709788ddb450dc3df59ca7f1a32fdfd1b4904eef Mon Sep 17 00:00:00 2001 From: Nick Tolley Date: Mon, 25 Sep 2023 13:48:06 -0400 Subject: [PATCH 17/30] Update example script --- examples/howto/plot_hnn_animation.py | 66 ++++++++++++++++++++-------- hnn_core/viz.py | 1 - 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/examples/howto/plot_hnn_animation.py b/examples/howto/plot_hnn_animation.py index 1b13f0402..4cd9aae16 100644 --- a/examples/howto/plot_hnn_animation.py +++ b/examples/howto/plot_hnn_animation.py @@ -1,6 +1,6 @@ """ ================================ -XX. Modifying local connectivity +06. Animating HNN simulations ================================ This example demonstrates how to animate HNN simulations @@ -10,28 +10,58 @@ ############################################################################### -from hnn_core import jones_2009_model, simulate_dipole +# First, we'll import the necessary modules for instantiating a network and +# running a simulation that we would like to animate. +import os.path as op +import hnn_core +from hnn_core import jones_2009_model, simulate_dipole, read_params from hnn_core.network_models import add_erp_drives_to_jones_model -from hnn_core.viz import NetworkPlot -net = jones_2009_model() -net.set_cell_positions(inplane_distance=300) -add_erp_drives_to_jones_model(net) -dpl = simulate_dipole(net, dt=0.5, tstop=170, record_vsec='all') +############################################################################### +# We begin by instantiating the network. For this example, we will reduce the +# number of cells in the network to speed up the simulations. +hnn_core_root = op.dirname(hnn_core.__file__) +params_fname = op.join(hnn_core_root, 'param', 'default.json') +params = read_params(params_fname) +params.update({'N_pyr_x': 3, 'N_pyr_y': 3}) +net = jones_2009_model(params) -net_plot = NetworkPlot(net) +# Note that we move the cells further apart to allow better visualization of +# the network (default inplane_distance=1.0 µm). +net.set_cell_positions(inplane_distance=300) ############################################################################### -from ipywidgets import interact, IntSlider +# The :class:`hnn_core.viz.NetworkPlotter` class can be used to visualize +# the 3D structure of the network. +from hnn_core.viz import NetworkPlotter -def update_plot(t_idx, elev, azim): - net_plot.update_section_voltages(t_idx) - net_plot.elev = elev - net_plot.azim = azim - return net_plot.fig +net_plot = NetworkPlotter(net) +net_plot.fig -time_slider = IntSlider(min=0, max=len(net_plot.times), value=1, continuous_update=False) -elev_slider = IntSlider(min=-100, max=100, value=10, continuous_update=False) -azim_slider = IntSlider(min=-100, max=100, value=-100, continuous_update=False) +############################################################################### +# We can also visualize the network from another angle by adjusting the +# azimuth and elevation parameters. +net_plot.azim = 45 +net_plot.elev = 40 +net_plot.fig -interact(update_plot, t_idx=time_slider, elev=elev_slider, azim=azim_slider) +############################################################################### +# Next we add event related potential (ERP) producing drives to the network +# and run the simulation (see +# :ref:`evoked example ` +# for more details). +# To visualize the membrane potential of cells in the +# network, we need use `simulate_dipole(..., record_vsec='all')` which turns +# on the recording of voltages in all sections of all cells in the network. +add_erp_drives_to_jones_model(net) +dpl = simulate_dipole(net, tstop=170, record_vsec='all') +net_plot = NetworkPlotter(net) # Reinitialize plotter with simulated network + +############################################################################### +# Finally, we can animate the simulation using the `export_movie()` method. We +# can adjust the xyz limits of the plot to better visualize the network. +net_plot.xlim = (400, 1600) +net_plot.ylim = (400, 1600) +net_plot.zlim = (-500, 1600) +net_plot.azim = 225 +net_plot.export_movie('animation_demo.gif', dpi=100, fps=30, interval=100) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 22b44b275..8a35b6782 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1298,7 +1298,6 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', else: self._vsec_recorded = False else: - self._is_simulated = False self._vsec_recorded = False self.times = None From 6b58aea9f576acf74ec760ba057389cc553979dd Mon Sep 17 00:00:00 2001 From: Nicholas Tolley <55253912+ntolley@users.noreply.github.com> Date: Mon, 20 Nov 2023 13:38:39 -0500 Subject: [PATCH 18/30] formatting Co-authored-by: Mainak Jas --- examples/howto/plot_hnn_animation.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/howto/plot_hnn_animation.py b/examples/howto/plot_hnn_animation.py index 4cd9aae16..dc1ae8f88 100644 --- a/examples/howto/plot_hnn_animation.py +++ b/examples/howto/plot_hnn_animation.py @@ -13,6 +13,7 @@ # First, we'll import the necessary modules for instantiating a network and # running a simulation that we would like to animate. import os.path as op + import hnn_core from hnn_core import jones_2009_model, simulate_dipole, read_params from hnn_core.network_models import add_erp_drives_to_jones_model From bb568cff01e5462a4d862425d34a1c342417fd2e Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Mon, 20 Nov 2023 18:21:02 -0500 Subject: [PATCH 19/30] respond to reviews and add colorbar functionality --- hnn_core/tests/test_viz.py | 11 ++++++ hnn_core/viz.py | 71 +++++++++++++++++++++++++++++++------- 2 files changed, 70 insertions(+), 12 deletions(-) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index d6aa76d9c..671163225 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -3,6 +3,8 @@ import matplotlib from matplotlib import backend_bases import matplotlib.pyplot as plt +from matplotlib.colorbar import Colorbar + import numpy as np from numpy.testing import assert_allclose import pytest @@ -240,6 +242,8 @@ def test_dipole_visualization(): _ = NetworkPlotter(net, trial_idx=1.0) with pytest.raises(TypeError, match='time_idx must be'): _ = NetworkPlotter(net, time_idx=1.0) + with pytest.raises(TypeError, match='colorbar must be'): + _ = NetworkPlotter(net, colorbar='blah') net = jones_2009_model(params) net_plot = NetworkPlotter(net) @@ -272,6 +276,7 @@ def test_dipole_visualization(): assert net_plot.vsec_array.shape == (159, 21) assert net_plot.color_array.shape == (159, 21, 4) assert net_plot._vsec_recorded is True + assert isinstance(net_plot._cbar, Colorbar) # Type check errors with pytest.raises(TypeError, match='xlim must be'): @@ -292,6 +297,8 @@ def test_dipole_visualization(): net_plot.trial_idx = 1.0 with pytest.raises(TypeError, match='time_idx must be'): net_plot.time_idx = 1.0 + with pytest.raises(TypeError, match='colorbar must be'): + net_plot.colorbar = 'blah' # Check that the setters work net_plot.xlim = (-100, 100) @@ -306,6 +313,10 @@ def test_dipole_visualization(): net_plot.bgcolor = 'white' net_plot.voltage_colormap = 'jet' + net_plot.colorbar = False + # check later that net._cbar is None + assert net_plot._cbar is None + # Check that the getters work assert net_plot.xlim == (-100, 100) assert net_plot.ylim == (-100, 100) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 8a35b6782..168c792d6 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1211,7 +1211,7 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, ax : instance of matplotlib figure | None The matplotlib axis. colorbar : bool - If the colorbar is presented. + If True (default), adjust figure to include colorbar. contact_labels : list Labels associated with the contacts to plot. Passed as-is to :func:`~matplotlib.axes.Axes.set_yticklabels`. @@ -1244,8 +1244,7 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True, class NetworkPlotter: - """Helper class to visualize full - morphology of HNN model. + """Helper class to visualize full morphology of HNN model. Parameters ---------- @@ -1262,6 +1261,8 @@ class NetworkPlotter: Default: 50 mV bg_color : str Background color of ax. Default: 'black' + colorbar : bool + If True (default), adjust figure to include colorbar. voltage_colormap : str Colormap used for plotting voltages Default: 'viridis' @@ -1281,7 +1282,7 @@ class NetworkPlotter: Index of time point plotted. Default: 0 """ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', - voltage_colormap='viridis', elev=10, azim=-500, + colorbar=True, voltage_colormap='viridis', elev=10, azim=-500, xlim=(-200, 3100), ylim=(-200, 3100), zlim=(-300, 2200), trial_idx=0, time_idx=0): import matplotlib.pyplot as plt @@ -1309,8 +1310,8 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', self._bg_color = bg_color self._voltage_colormap = voltage_colormap - self.colormaps = colormaps # Saved for voltage_colormap update method - self.colormap = colormaps[voltage_colormap] + self._colormaps = colormaps # Saved for voltage_colormap update method + self._colormap = colormaps[voltage_colormap] # Axes limits and view positions _validate_type(xlim, tuple, 'xlim') @@ -1334,7 +1335,7 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', # Get voltage data and corresponding colors self.vsec_array = self._get_voltages() - self.color_array = self.colormap(self.vsec_array) + self.color_array = self._colormap(self.vsec_array) # Create figure if ax is None: @@ -1347,6 +1348,13 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', self._init_network_plot() self._update_axes() + _validate_type(colorbar, bool, 'colorbar') + self._colorbar = colorbar + if self._colorbar: + self._update_colorbar() + else: + self._cbar = None + def _get_voltages(self): vsec_list = list() for cell_type in self.net.cell_types: @@ -1395,10 +1403,23 @@ def _update_axes(self): self.ax.view_init(self._elev, self._azim) + def _update_colorbar(self): + import matplotlib.pyplot as plt + import matplotlib.colors as mc + + fig = self.ax.get_figure() + sm = plt.cm.ScalarMappable( + cmap=self.voltage_colormap, + norm=mc.Normalize(vmin=self.vmin, vmax=self.vmax)) + self._cbar = fig.colorbar(sm, ax=self.ax) + def export_movie(self, fname, fps=30, dpi=300, decim=10, interval=30, frame_start=0, frame_stop=None, writer='pillow'): """Export movie of network activity + + Parameters + ---------- fname : str Filename of exported movie fps : int @@ -1498,7 +1519,10 @@ def vmin(self, vmin): _validate_type(vmin, (int, float), 'vmin') self._vmin = vmin self.vsec_array = self._get_voltages() - self.color_array = self.colormap(self.vsec_array) + self.color_array = self._colormap(self.vsec_array) + if self._colorbar: + self._cbar.remove() + self._update_colorbar() @property def vmax(self): @@ -1509,7 +1533,10 @@ def vmax(self, vmax): _validate_type(vmax, (int, float), 'vmax') self._vmax = vmax self.vsec_array = self._get_voltages() - self.color_array = self.colormap(self.vsec_array) + self.color_array = self._colormap(self.vsec_array) + if self._colorbar: + self._cbar.remove() + self._update_colorbar() # Time and trial indices @property @@ -1521,7 +1548,7 @@ def trial_idx(self, trial_idx): _validate_type(trial_idx, int, 'trial_idx') self._trial_idx = trial_idx self.vsec_array = self._get_voltages() - self.color_array = self.colormap(self.vsec_array) + self.color_array = self._colormap(self.vsec_array) @property def time_idx(self): @@ -1554,5 +1581,25 @@ def voltage_colormap(self): @voltage_colormap.setter def voltage_colormap(self, voltage_colormap): self._voltage_colormap = voltage_colormap - self.colormap = self.colormaps[self._voltage_colormap] - self.color_array = self.colormap(self.vsec_array) + self._colormap = self._colormaps[self._voltage_colormap] + self.color_array = self._colormap(self.vsec_array) + if self._colorbar: + self._cbar.remove() + self._update_colorbar() + + @property + def colorbar(self): + return self._colorbar + + @colorbar.setter + def colorbar(self, colorbar): + _validate_type(colorbar, bool, 'colorbar') + self._colorbar = colorbar + if self._colorbar: + # Remove old colorbar if already exists + if self._cbar is not None: + self._cbar.remove() + self._update_colorbar() + else: + self._cbar.remove() + self._cbar = None From 1a5697befe709c8897d9d507c4d1b28118d2a384 Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Mon, 20 Nov 2023 18:22:58 -0500 Subject: [PATCH 20/30] add to api.rst --- doc/api.rst | 1 + 1 file changed, 1 insertion(+) diff --git a/doc/api.rst b/doc/api.rst index 7b13e4af8..2b550ad2a 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -82,6 +82,7 @@ Visualization (:py:mod:`hnn_core.viz`): plot_connectivity_matrix plot_laminar_lfp plot_laminar_csd + NetworkPlotter Parallel backends (:py:mod:`hnn_core.parallel_backends`): --------------------------------------------------------- From 68fc6c6115cc9fab2a20d52b144fa1dda6d11922 Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Wed, 15 May 2024 10:36:15 -0400 Subject: [PATCH 21/30] Fix test with smaller net --- hnn_core/tests/test_viz.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 671163225..698ad0d52 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -245,7 +245,7 @@ def test_dipole_visualization(): with pytest.raises(TypeError, match='colorbar must be'): _ = NetworkPlotter(net, colorbar='blah') - net = jones_2009_model(params) + net = jones_2009_model(params, mesh_shape=(3, 3)) net_plot = NetworkPlotter(net) assert net_plot.vsec_array.shape == (159, 1) @@ -268,7 +268,7 @@ def test_dipole_visualization(): net_plot.export_movie('demo.gif', dpi=200) # Simulate with record_vsec='all' to test voltage plotting - net = jones_2009_model(params) + net = jones_2009_model(params, mesh_shape=(3, 3)) _ = simulate_dipole(net, dt=0.5, tstop=10, n_trials=2, record_vsec='all') net_plot = NetworkPlotter(net) From f6ccccc164e5bec3be6d020deee090e7e2927fad Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Wed, 15 May 2024 10:43:52 -0400 Subject: [PATCH 22/30] Use mesh_shape in example notebook --- examples/howto/plot_hnn_animation.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/examples/howto/plot_hnn_animation.py b/examples/howto/plot_hnn_animation.py index dc1ae8f88..04a8e7704 100644 --- a/examples/howto/plot_hnn_animation.py +++ b/examples/howto/plot_hnn_animation.py @@ -21,11 +21,7 @@ ############################################################################### # We begin by instantiating the network. For this example, we will reduce the # number of cells in the network to speed up the simulations. -hnn_core_root = op.dirname(hnn_core.__file__) -params_fname = op.join(hnn_core_root, 'param', 'default.json') -params = read_params(params_fname) -params.update({'N_pyr_x': 3, 'N_pyr_y': 3}) -net = jones_2009_model(params) +net = jones_2009_model(add_drives_from_params=True, mesh_shape=(3, 3)) # Note that we move the cells further apart to allow better visualization of # the network (default inplane_distance=1.0 µm). From dcd55aaaf944e862f423731da9ff4aa7be781b6b Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Wed, 15 May 2024 10:51:37 -0400 Subject: [PATCH 23/30] update plot_morphology docstring --- examples/howto/plot_hnn_animation.py | 2 +- hnn_core/cell.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/examples/howto/plot_hnn_animation.py b/examples/howto/plot_hnn_animation.py index 04a8e7704..858ee2e99 100644 --- a/examples/howto/plot_hnn_animation.py +++ b/examples/howto/plot_hnn_animation.py @@ -21,7 +21,7 @@ ############################################################################### # We begin by instantiating the network. For this example, we will reduce the # number of cells in the network to speed up the simulations. -net = jones_2009_model(add_drives_from_params=True, mesh_shape=(3, 3)) +net = jones_2009_model(mesh_shape=(3, 3)) # Note that we move the cells further apart to allow better visualization of # the network (default inplane_distance=1.0 µm). diff --git a/hnn_core/cell.py b/hnn_core/cell.py index 3c93682d4..110034b82 100644 --- a/hnn_core/cell.py +++ b/hnn_core/cell.py @@ -882,6 +882,12 @@ def plot_morphology(self, ax=None, color=None, pos=(0, 0, 0), pos : tuple of int or float | None Position of cell soma. Must be a tuple of 3 elements for the (x, y, z) position of the soma in 3D space. Default: (0, 0, 0) + xlim : tuple of int | tuple of float + x limits of plot window. Default (-250, 150) + ylim : tuple of int | tuple of float + y limits of plot window. Default (-100, 100) + zlim : tuple of int | tuple of float + z limits of plot window. Default (-100, 1200) show : bool If True, show the plot From dc6e2f084f60de79bcb1700f07585b61dc197843 Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Wed, 15 May 2024 12:01:52 -0400 Subject: [PATCH 24/30] refactor network_plotter tests --- hnn_core/tests/test_viz.py | 66 +++++++++++++++++++++++++++----------- hnn_core/viz.py | 9 ++++++ 2 files changed, 56 insertions(+), 19 deletions(-) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index 698ad0d52..fc5244a0a 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -19,6 +19,15 @@ matplotlib.use('agg') +@pytest.fixture +def setup_net(): + hnn_core_root = op.dirname(hnn_core.__file__) + params_fname = op.join(hnn_core_root, 'param', 'default.json') + params = read_params(params_fname) + net = jones_2009_model(params, mesh_shape=(3, 3)) + + return net + def _fake_click(fig, ax, point, button=1): """Fake a click at a point within axes.""" x, y = ax.transData.transform_point(point) @@ -29,12 +38,9 @@ def _fake_click(fig, ax, point, button=1): fig.canvas.callbacks.process('button_press_event', button_press_event) -def test_network_visualization(): +def test_network_visualization(setup_net): """Test network visualisations.""" - hnn_core_root = op.dirname(hnn_core.__file__) - params_fname = op.join(hnn_core_root, 'param', 'default.json') - params = read_params(params_fname) - net = jones_2009_model(params, mesh_shape=(3, 3)) + net = setup_net plot_cells(net) ax = net.cell_types['L2_pyramidal'].plot_morphology() assert len(ax.lines) == 8 @@ -114,12 +120,9 @@ def test_network_visualization(): plt.close('all') -def test_dipole_visualization(): +def test_dipole_visualization(setup_net): """Test dipole visualisations.""" - hnn_core_root = op.dirname(hnn_core.__file__) - params_fname = op.join(hnn_core_root, 'param', 'default.json') - params = read_params(params_fname) - net = jones_2009_model(params, mesh_shape=(3, 3)) + net = setup_net weights_ampa = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5} syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.} @@ -223,6 +226,10 @@ def test_dipole_visualization(): with pytest.raises(ValueError, match="'beta_dist' must be"): net.cell_response.plot_spikes_hist(color={'beta_prox': 'r'}) + +def test_network_plotter_init(setup_net): + """Test init keywords of NetworkPlotter class.""" + net = setup_net # test NetworkPlotter class with pytest.raises(TypeError, match='xlim must be'): _ = NetworkPlotter(net, xlim='blah') @@ -245,13 +252,17 @@ def test_dipole_visualization(): with pytest.raises(TypeError, match='colorbar must be'): _ = NetworkPlotter(net, colorbar='blah') - net = jones_2009_model(params, mesh_shape=(3, 3)) net_plot = NetworkPlotter(net) assert net_plot.vsec_array.shape == (159, 1) assert net_plot.color_array.shape == (159, 1, 4) assert net_plot._vsec_recorded is False + +def test_network_plotter_simulation(setup_net): + """Test NetworkPlotter class simulation warnings.""" + net = setup_net + net_plot = NetworkPlotter(net) # Errors if vsec isn't recorded with pytest.raises(RuntimeError, match='Network must be simulated'): net_plot.export_movie('demo.gif', dpi=200) @@ -267,17 +278,23 @@ def test_dipole_visualization(): with pytest.raises(RuntimeError, match='Network must be simulated'): net_plot.export_movie('demo.gif', dpi=200) - # Simulate with record_vsec='all' to test voltage plotting - net = jones_2009_model(params, mesh_shape=(3, 3)) - _ = simulate_dipole(net, dt=0.5, tstop=10, n_trials=2, - record_vsec='all') + net = setup_net + _ = simulate_dipole(net, dt=0.5, tstop=10, record_vsec='all') net_plot = NetworkPlotter(net) + # setter/getter test for time_idx + net_plot.time_idx = 5 + assert net_plot.time_idx == 5 assert net_plot.vsec_array.shape == (159, 21) assert net_plot.color_array.shape == (159, 21, 4) assert net_plot._vsec_recorded is True assert isinstance(net_plot._cbar, Colorbar) + +def test_network_plotter_setter(setup_net): + """Test NetworkPlotter class setters and getters.""" + net = setup_net + net_plot = NetworkPlotter(net) # Type check errors with pytest.raises(TypeError, match='xlim must be'): net_plot.xlim = 'blah' @@ -308,15 +325,19 @@ def test_dipole_visualization(): net_plot.azim = 10 net_plot.vmin = 0 net_plot.vmax = 100 - net_plot.trial_idx = 1 - net_plot.time_idx = 5 net_plot.bgcolor = 'white' net_plot.voltage_colormap = 'jet' net_plot.colorbar = False - # check later that net._cbar is None assert net_plot._cbar is None + # time_idx setter should raise an error if network is not simulated + with pytest.raises(RuntimeError, match='Network must be simulated'): + net_plot.time_idx = 5 + + with pytest.raises(RuntimeError, match='Network must be simulated'): + net_plot.trial_idx = 1 + # Check that the getters work assert net_plot.xlim == (-100, 100) assert net_plot.ylim == (-100, 100) @@ -326,13 +347,20 @@ def test_dipole_visualization(): assert net_plot.vmin == 0 assert net_plot.vmax == 100 assert net_plot.trial_idx == 1 - assert net_plot.time_idx == 5 assert net_plot.bgcolor == 'white' assert net_plot.fig.get_facecolor() == (1.0, 1.0, 1.0, 1.0) assert net_plot.voltage_colormap == 'jet' + +def test_network_plotter_export(setup_net): + """Test NetworkPlotter class export methods.""" + net = setup_net + _ = simulate_dipole(net, dt=0.5, tstop=10, n_trials=1, + record_vsec='all') + net_plot = NetworkPlotter(net) + # Test animation export and voltage plotting net_plot.export_movie('demo.gif', dpi=200, decim=100, writer='pillow') diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 168c792d6..65d1b8dfc 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -876,6 +876,14 @@ def plot_cell_morphology( pos : tuple of int or float | None Position of cell soma. Must be a tuple of 3 elements for the (x, y, z) position of the soma in 3D space. Default: (0, 0, 0) + xlim : tuple of int | tuple of float + x limits of plot window. Default (-250, 150) + ylim : tuple of int | tuple of float + y limits of plot window. Default (-100, 100) + zlim : tuple of int | tuple of float + z limits of plot window. Default (-100, 1200) + show : bool + If True, show the plot Returns ------- @@ -1549,6 +1557,7 @@ def trial_idx(self, trial_idx): self._trial_idx = trial_idx self.vsec_array = self._get_voltages() self.color_array = self._colormap(self.vsec_array) + self._update_section_voltages(self._time_idx) @property def time_idx(self): From 303838d4a6bf8d8b4b2f430d1f7cad5c2c2b2f29 Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Wed, 15 May 2024 12:33:52 -0400 Subject: [PATCH 25/30] better test readability --- hnn_core/tests/test_viz.py | 75 +++++++++++++------------------------- 1 file changed, 25 insertions(+), 50 deletions(-) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index fc5244a0a..a671dda50 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -28,6 +28,7 @@ def setup_net(): return net + def _fake_click(fig, ax, point, button=1): """Fake a click at a point within axes.""" x, y = ax.transData.transform_point(point) @@ -225,6 +226,7 @@ def test_dipole_visualization(setup_net): 'beta_dist': 'g'}) with pytest.raises(ValueError, match="'beta_dist' must be"): net.cell_response.plot_spikes_hist(color={'beta_prox': 'r'}) + plt.close('all') def test_network_plotter_init(setup_net): @@ -257,6 +259,7 @@ def test_network_plotter_init(setup_net): assert net_plot.vsec_array.shape == (159, 1) assert net_plot.color_array.shape == (159, 1, 4) assert net_plot._vsec_recorded is False + plt.close('all') def test_network_plotter_simulation(setup_net): @@ -279,16 +282,19 @@ def test_network_plotter_simulation(setup_net): net_plot.export_movie('demo.gif', dpi=200) net = setup_net - _ = simulate_dipole(net, dt=0.5, tstop=10, record_vsec='all') + _ = simulate_dipole(net, dt=0.5, tstop=10, record_vsec='all', n_trials=2) net_plot = NetworkPlotter(net) - # setter/getter test for time_idx + # setter/getter test for time_idx and trial_idx net_plot.time_idx = 5 assert net_plot.time_idx == 5 + net_plot.trial_idx = 1 + assert net_plot.trial_idx == 1 assert net_plot.vsec_array.shape == (159, 21) assert net_plot.color_array.shape == (159, 21, 4) assert net_plot._vsec_recorded is True assert isinstance(net_plot._cbar, Colorbar) + plt.close('all') def test_network_plotter_setter(setup_net): @@ -296,40 +302,23 @@ def test_network_plotter_setter(setup_net): net = setup_net net_plot = NetworkPlotter(net) # Type check errors - with pytest.raises(TypeError, match='xlim must be'): - net_plot.xlim = 'blah' - with pytest.raises(TypeError, match='ylim must be'): - net_plot.ylim = 'blah' - with pytest.raises(TypeError, match='zlim must be'): - net_plot.zlim = 'blah' - with pytest.raises(TypeError, match='elev must be'): - net_plot.elev = 'blah' - with pytest.raises(TypeError, match='azim must be'): - net_plot.azim = 'blah' - with pytest.raises(TypeError, match='vmin must be'): - net_plot.vmin = 'blah' - with pytest.raises(TypeError, match='vmax must be'): - net_plot.vmax = 'blah' - with pytest.raises(TypeError, match='trial_idx must be'): - net_plot.trial_idx = 1.0 - with pytest.raises(TypeError, match='time_idx must be'): - net_plot.time_idx = 1.0 - with pytest.raises(TypeError, match='colorbar must be'): - net_plot.colorbar = 'blah' - - # Check that the setters work - net_plot.xlim = (-100, 100) - net_plot.ylim = (-100, 100) - net_plot.zlim = (-100, 100) - net_plot.elev = 10 - net_plot.azim = 10 - net_plot.vmin = 0 - net_plot.vmax = 100 - net_plot.bgcolor = 'white' - net_plot.voltage_colormap = 'jet' - - net_plot.colorbar = False + args = ['xlim', 'ylim', 'zlim', 'elev', 'azim', 'vmin', 'vmax', + 'trial_idx', 'time_idx', 'colorbar'] + for arg in args: + with pytest.raises(TypeError, match=f'{arg} must be'): + setattr(net_plot, arg, 'blah') + + # Check that the setters and getters work + arg_dict = {'xlim': (-100, 100), 'ylim': (-100, 100), 'zlim': (-100, 100), + 'elev': 10, 'azim': 10, 'vmin': 0, 'vmax': 100, + 'bgcolor': 'white', 'voltage_colormap': 'jet', + 'colorbar': False} + for arg, val in arg_dict.items(): + setattr(net_plot, arg, val) + assert getattr(net_plot, arg) == val + assert net_plot._cbar is None + assert net_plot.fig.get_facecolor() == (1.0, 1.0, 1.0, 1.0) # time_idx setter should raise an error if network is not simulated with pytest.raises(RuntimeError, match='Network must be simulated'): @@ -337,21 +326,7 @@ def test_network_plotter_setter(setup_net): with pytest.raises(RuntimeError, match='Network must be simulated'): net_plot.trial_idx = 1 - - # Check that the getters work - assert net_plot.xlim == (-100, 100) - assert net_plot.ylim == (-100, 100) - assert net_plot.zlim == (-100, 100) - assert net_plot.elev == 10 - assert net_plot.azim == 10 - assert net_plot.vmin == 0 - assert net_plot.vmax == 100 - assert net_plot.trial_idx == 1 - - assert net_plot.bgcolor == 'white' - assert net_plot.fig.get_facecolor() == (1.0, 1.0, 1.0, 1.0) - - assert net_plot.voltage_colormap == 'jet' + plt.close('all') def test_network_plotter_export(setup_net): From 7d4939148729a68398e254153e360bd356f4386c Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Wed, 15 May 2024 14:08:38 -0400 Subject: [PATCH 26/30] refactor init --- hnn_core/viz.py | 101 +++++++++++++++++++++++++++--------------------- 1 file changed, 58 insertions(+), 43 deletions(-) diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 65d1b8dfc..312425f6e 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1293,76 +1293,83 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black', colorbar=True, voltage_colormap='viridis', elev=10, azim=-500, xlim=(-200, 3100), ylim=(-200, 3100), zlim=(-300, 2200), trial_idx=0, time_idx=0): - import matplotlib.pyplot as plt from matplotlib import colormaps - self.net = net - - # Check if network simulated - if net.cell_response is not None: - self.times = net.cell_response.times - # Check if voltage recorded - if net._params['record_vsec'] == 'all': - self._vsec_recorded = True - else: - self._vsec_recorded = False - else: - self._vsec_recorded = False - self.times = None + self._validate_parameters(vmin, vmax, bg_color, voltage_colormap, + colorbar, elev, azim, xlim, ylim, zlim, + trial_idx, time_idx) - _validate_type(vmin, (int, float), 'vmin') - _validate_type(vmax, (int, float), 'vmax') + # Set init arguments + self.net = net + self.ax = ax self._vmin = vmin self._vmax = vmax - self._bg_color = bg_color + self._colorbar = colorbar self._voltage_colormap = voltage_colormap + self._colormaps = colormaps + self._xlim = xlim + self._ylim = ylim + self._zlim = zlim + self._elev = elev + self._azim = azim + self._trial_idx = trial_idx + self._time_idx = time_idx - self._colormaps = colormaps # Saved for voltage_colormap update method + # Check if Network object is simulated + self.times, self._vsec_recorded = self._check_network_simulation() + + # Initialize plots and colormap + self.fig = None self._colormap = colormaps[voltage_colormap] + self.vsec_array = self._get_voltages() + self.color_array = self._colormap(self.vsec_array) - # Axes limits and view positions + self._initialize_plots() + if self._colorbar: + self._update_colorbar() + else: + self._cbar = None + + def _validate_parameters(self, vmin, vmax, bg_color, voltage_colormap, + colorbar, elev, azim, xlim, ylim, zlim, trial_idx, + time_idx): + _validate_type(vmin, (int, float), 'vmin') + _validate_type(vmax, (int, float), 'vmax') + _validate_type(bg_color, str, 'bg_color') + _validate_type(voltage_colormap, str, 'voltage_colormap') + _validate_type(colorbar, bool, 'colorbar') _validate_type(xlim, tuple, 'xlim') _validate_type(ylim, tuple, 'ylim') _validate_type(zlim, tuple, 'zlim') _validate_type(elev, (int, float), 'elev') _validate_type(azim, (int, float), 'azim') - - self._xlim = xlim - self._ylim = ylim - self._zlim = zlim - self._elev = elev - self._azim = azim - - # Trial and time indices _validate_type(trial_idx, int, 'trial_idx') _validate_type(time_idx, int, 'time_idx') - self._trial_idx = trial_idx - self._time_idx = time_idx + def _check_network_simulation(self): + times = None + vsec_recorded = False + # Check if network simulated + if self.net.cell_response is not None: + times = self.net.cell_response.times - # Get voltage data and corresponding colors - self.vsec_array = self._get_voltages() - self.color_array = self._colormap(self.vsec_array) + # Check if voltage recorded + if self.net._params['record_vsec'] == 'all': + vsec_recorded = True + return times, vsec_recorded + def _initialize_plots(self): + import matplotlib.pyplot as plt # Create figure - if ax is None: + if self.ax is None: self.fig = plt.figure() self.ax = self.fig.add_subplot(projection='3d') self.ax.set_facecolor(self._bg_color) - else: - self.ax = ax - self.fig = None + self._init_network_plot() self._update_axes() - _validate_type(colorbar, bool, 'colorbar') - self._colorbar = colorbar - if self._colorbar: - self._update_colorbar() - else: - self._cbar = None - def _get_voltages(self): vsec_list = list() for cell_type in self.net.cell_types: @@ -1554,6 +1561,10 @@ def trial_idx(self): @trial_idx.setter def trial_idx(self, trial_idx): _validate_type(trial_idx, int, 'trial_idx') + if not self._vsec_recorded: + raise RuntimeError("Network must be simulated with" + "`simulate_dipole(record_vsec='all')` before" + "setting `trial_idx`.") self._trial_idx = trial_idx self.vsec_array = self._get_voltages() self.color_array = self._colormap(self.vsec_array) @@ -1566,6 +1577,10 @@ def time_idx(self): @time_idx.setter def time_idx(self, time_idx): _validate_type(time_idx, (int, np.integer), 'time_idx') + if not self._vsec_recorded: + raise RuntimeError("Network must be simulated with" + "`simulate_dipole(record_vsec='all')` before" + "setting `time_idx`.") self._time_idx = time_idx self._update_section_voltages(self._time_idx) From 3b80ceffa55e2ef1cc64367762ca47c1ab83b694 Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Wed, 15 May 2024 14:54:14 -0400 Subject: [PATCH 27/30] test animation export file exists --- hnn_core/tests/test_viz.py | 10 ++++++++-- hnn_core/viz.py | 6 +++--- 2 files changed, 11 insertions(+), 5 deletions(-) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index a671dda50..d6d2579f0 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -329,14 +329,20 @@ def test_network_plotter_setter(setup_net): plt.close('all') -def test_network_plotter_export(setup_net): +def test_network_plotter_export(tmp_path, setup_net): """Test NetworkPlotter class export methods.""" net = setup_net _ = simulate_dipole(net, dt=0.5, tstop=10, n_trials=1, record_vsec='all') net_plot = NetworkPlotter(net) + # Check no file is already written + path_out = tmp_path / 'demo.gif' + assert not path_out.is_file() + # Test animation export and voltage plotting - net_plot.export_movie('demo.gif', dpi=200, decim=100, writer='pillow') + net_plot.export_movie(path_out, dpi=200, decim=100, writer='pillow') + + assert path_out.is_file() plt.close('all') diff --git a/hnn_core/viz.py b/hnn_core/viz.py index 312425f6e..e022f531f 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -921,11 +921,11 @@ def plot_cell_morphology( for sec_name, section in cell.sections.items(): linewidth = _linewidth_from_data_units(ax, section.diam) end_pts = section.end_pts + dx = pos[0] - cell.sections['soma'].end_pts[0][0] + dy = pos[1] - cell.sections['soma'].end_pts[0][1] + dz = pos[2] - cell.sections['soma'].end_pts[0][2] xs, ys, zs = list(), list(), list() for pt in end_pts: - dx = pos[0] - cell.sections['soma'].end_pts[0][0] - dy = pos[1] - cell.sections['soma'].end_pts[0][1] - dz = pos[2] - cell.sections['soma'].end_pts[0][2] xs.append(pt[0] + dx) ys.append(pt[1] + dz) zs.append(pt[2] + dy) From 71e1d9ab0a96cfba6cd29b17adadc91bb3b7823e Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Wed, 15 May 2024 16:19:12 -0400 Subject: [PATCH 28/30] loop for init test --- hnn_core/tests/test_viz.py | 25 +++++-------------------- 1 file changed, 5 insertions(+), 20 deletions(-) diff --git a/hnn_core/tests/test_viz.py b/hnn_core/tests/test_viz.py index d6d2579f0..bb4020d50 100644 --- a/hnn_core/tests/test_viz.py +++ b/hnn_core/tests/test_viz.py @@ -233,26 +233,11 @@ def test_network_plotter_init(setup_net): """Test init keywords of NetworkPlotter class.""" net = setup_net # test NetworkPlotter class - with pytest.raises(TypeError, match='xlim must be'): - _ = NetworkPlotter(net, xlim='blah') - with pytest.raises(TypeError, match='ylim must be'): - _ = NetworkPlotter(net, ylim='blah') - with pytest.raises(TypeError, match='zlim must be'): - _ = NetworkPlotter(net, zlim='blah') - with pytest.raises(TypeError, match='elev must be'): - _ = NetworkPlotter(net, elev='blah') - with pytest.raises(TypeError, match='azim must be'): - _ = NetworkPlotter(net, azim='blah') - with pytest.raises(TypeError, match='vmin must be'): - _ = NetworkPlotter(net, vmin='blah') - with pytest.raises(TypeError, match='vmax must be'): - _ = NetworkPlotter(net, vmax='blah') - with pytest.raises(TypeError, match='trial_idx must be'): - _ = NetworkPlotter(net, trial_idx=1.0) - with pytest.raises(TypeError, match='time_idx must be'): - _ = NetworkPlotter(net, time_idx=1.0) - with pytest.raises(TypeError, match='colorbar must be'): - _ = NetworkPlotter(net, colorbar='blah') + args = ['xlim', 'ylim', 'zlim', 'elev', 'azim', 'vmin', 'vmax', + 'trial_idx', 'time_idx', 'colorbar'] + for arg in args: + with pytest.raises(TypeError, match=f'{arg} must be'): + net_plot = NetworkPlotter(net, **{arg: 'blah'}) net_plot = NetworkPlotter(net) From 15d856ea48a59238ed1359d596b16e316ae62441 Mon Sep 17 00:00:00 2001 From: Nicholas Tolley Date: Wed, 22 May 2024 18:48:45 -0400 Subject: [PATCH 29/30] update whats_new --- doc/whats_new.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 6a8b17bfb..8c6854552 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -58,6 +58,9 @@ Changelog - Added feature to read/write :class:`~hnn_core.Network` configurations to json, by `George Dang`_ and `Rajat Partani`_ in :gh:`757` +- Added :class:`~hnn_core/viz/NetworkPlotter` to visualize and animate network simulations, + by `Nick Tolley`_ in :gh:`649` + Bug ~~~ - Fix inconsistent connection mapping from drive gids to cell gids, by From e9a94ef69633ff2f306e3e8c642e72b6bdc0e478 Mon Sep 17 00:00:00 2001 From: Nicholas Tolley <55253912+ntolley@users.noreply.github.com> Date: Thu, 23 May 2024 14:05:52 -0400 Subject: [PATCH 30/30] Apply suggestions from code review Co-authored-by: Ryan Thorpe --- doc/whats_new.rst | 2 +- hnn_core/viz.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/doc/whats_new.rst b/doc/whats_new.rst index 8c6854552..62c0fb3e9 100644 --- a/doc/whats_new.rst +++ b/doc/whats_new.rst @@ -59,7 +59,7 @@ Changelog json, by `George Dang`_ and `Rajat Partani`_ in :gh:`757` - Added :class:`~hnn_core/viz/NetworkPlotter` to visualize and animate network simulations, - by `Nick Tolley`_ in :gh:`649` + by `Nick Tolley`_ in :gh:`649`. Bug ~~~ diff --git a/hnn_core/viz.py b/hnn_core/viz.py index e022f531f..066cc74e1 100644 --- a/hnn_core/viz.py +++ b/hnn_core/viz.py @@ -1503,7 +1503,7 @@ def zlim(self, zlim): self._zlim = zlim self.ax.set_zlim(self._zlim) - # Eleevation and azimuth of 3D viewpoint + # Elevation and azimuth of 3D viewpoint @property def elev(self): return self._elev