Skip to content

Commit

Permalink
respond to reviews and add colorbar functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
ntolley committed Nov 20, 2023
1 parent d97aa0e commit 9174280
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 12 deletions.
11 changes: 11 additions & 0 deletions hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import matplotlib
from matplotlib import backend_bases
import matplotlib.pyplot as plt
from matplotlib.colorbar import Colorbar

import numpy as np
from numpy.testing import assert_allclose
import pytest
Expand Down Expand Up @@ -239,6 +241,8 @@ def test_dipole_visualization():
_ = NetworkPlotter(net, trial_idx=1.0)
with pytest.raises(TypeError, match='time_idx must be'):
_ = NetworkPlotter(net, time_idx=1.0)
with pytest.raises(TypeError, match='colorbar must be'):
_ = NetworkPlotter(net, colorbar='blah')

net = jones_2009_model(params)
net_plot = NetworkPlotter(net)
Expand Down Expand Up @@ -271,6 +275,7 @@ def test_dipole_visualization():
assert net_plot.vsec_array.shape == (159, 21)
assert net_plot.color_array.shape == (159, 21, 4)
assert net_plot._vsec_recorded is True
assert isinstance(net_plot._cbar, Colorbar)

# Type check errors
with pytest.raises(TypeError, match='xlim must be'):
Expand All @@ -291,6 +296,8 @@ def test_dipole_visualization():
net_plot.trial_idx = 1.0
with pytest.raises(TypeError, match='time_idx must be'):
net_plot.time_idx = 1.0
with pytest.raises(TypeError, match='colorbar must be'):
net_plot.colorbar = 'blah'

# Check that the setters work
net_plot.xlim = (-100, 100)
Expand All @@ -305,6 +312,10 @@ def test_dipole_visualization():
net_plot.bgcolor = 'white'
net_plot.voltage_colormap = 'jet'

net_plot.colorbar = False
# check later that net._cbar is None
assert net_plot._cbar is None

# Check that the getters work
assert net_plot.xlim == (-100, 100)
assert net_plot.ylim == (-100, 100)
Expand Down
71 changes: 59 additions & 12 deletions hnn_core/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -1205,7 +1205,7 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True,
ax : instance of matplotlib figure | None
The matplotlib axis.
colorbar : bool
If the colorbar is presented.
If True (default), adjust figure to include colorbar.
contact_labels : list
Labels associated with the contacts to plot. Passed as-is to
:func:`~matplotlib.axes.Axes.set_yticklabels`.
Expand Down Expand Up @@ -1238,8 +1238,7 @@ def plot_laminar_csd(times, data, contact_labels, ax=None, colorbar=True,


class NetworkPlotter:
"""Helper class to visualize full
morphology of HNN model.
"""Helper class to visualize full morphology of HNN model.
Parameters
----------
Expand All @@ -1256,6 +1255,8 @@ class NetworkPlotter:
Default: 50 mV
bg_color : str
Background color of ax. Default: 'black'
colorbar : bool
If True (default), adjust figure to include colorbar.
voltage_colormap : str
Colormap used for plotting voltages
Default: 'viridis'
Expand All @@ -1275,7 +1276,7 @@ class NetworkPlotter:
Index of time point plotted. Default: 0
"""
def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black',
voltage_colormap='viridis', elev=10, azim=-500,
colorbar=True, voltage_colormap='viridis', elev=10, azim=-500,
xlim=(-200, 3100), ylim=(-200, 3100), zlim=(-300, 2200),
trial_idx=0, time_idx=0):
import matplotlib.pyplot as plt
Expand Down Expand Up @@ -1303,8 +1304,8 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black',
self._bg_color = bg_color
self._voltage_colormap = voltage_colormap

self.colormaps = colormaps # Saved for voltage_colormap update method
self.colormap = colormaps[voltage_colormap]
self._colormaps = colormaps # Saved for voltage_colormap update method
self._colormap = colormaps[voltage_colormap]

# Axes limits and view positions
_validate_type(xlim, tuple, 'xlim')
Expand All @@ -1328,7 +1329,7 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black',

# Get voltage data and corresponding colors
self.vsec_array = self._get_voltages()
self.color_array = self.colormap(self.vsec_array)
self.color_array = self._colormap(self.vsec_array)

# Create figure
if ax is None:
Expand All @@ -1341,6 +1342,13 @@ def __init__(self, net, ax=None, vmin=-100, vmax=50, bg_color='black',
self._init_network_plot()
self._update_axes()

_validate_type(colorbar, bool, 'colorbar')
self._colorbar = colorbar
if self._colorbar:
self._update_colorbar()
else:
self._cbar = None

def _get_voltages(self):
vsec_list = list()
for cell_type in self.net.cell_types:
Expand Down Expand Up @@ -1389,10 +1397,23 @@ def _update_axes(self):

self.ax.view_init(self._elev, self._azim)

def _update_colorbar(self):
import matplotlib.pyplot as plt
import matplotlib.colors as mc

fig = self.ax.get_figure()
sm = plt.cm.ScalarMappable(
cmap=self.voltage_colormap,
norm=mc.Normalize(vmin=self.vmin, vmax=self.vmax))
self._cbar = fig.colorbar(sm, ax=self.ax)

def export_movie(self, fname, fps=30, dpi=300, decim=10,
interval=30, frame_start=0, frame_stop=None,
writer='pillow'):
"""Export movie of network activity
Parameters
----------
fname : str
Filename of exported movie
fps : int
Expand Down Expand Up @@ -1492,7 +1513,10 @@ def vmin(self, vmin):
_validate_type(vmin, (int, float), 'vmin')
self._vmin = vmin
self.vsec_array = self._get_voltages()
self.color_array = self.colormap(self.vsec_array)
self.color_array = self._colormap(self.vsec_array)
if self._colorbar:
self._cbar.remove()
self._update_colorbar()

@property
def vmax(self):
Expand All @@ -1503,7 +1527,10 @@ def vmax(self, vmax):
_validate_type(vmax, (int, float), 'vmax')
self._vmax = vmax
self.vsec_array = self._get_voltages()
self.color_array = self.colormap(self.vsec_array)
self.color_array = self._colormap(self.vsec_array)
if self._colorbar:
self._cbar.remove()
self._update_colorbar()

# Time and trial indices
@property
Expand All @@ -1515,7 +1542,7 @@ def trial_idx(self, trial_idx):
_validate_type(trial_idx, int, 'trial_idx')
self._trial_idx = trial_idx
self.vsec_array = self._get_voltages()
self.color_array = self.colormap(self.vsec_array)
self.color_array = self._colormap(self.vsec_array)

@property
def time_idx(self):
Expand Down Expand Up @@ -1548,5 +1575,25 @@ def voltage_colormap(self):
@voltage_colormap.setter
def voltage_colormap(self, voltage_colormap):
self._voltage_colormap = voltage_colormap
self.colormap = self.colormaps[self._voltage_colormap]
self.color_array = self.colormap(self.vsec_array)
self._colormap = self._colormaps[self._voltage_colormap]
self.color_array = self._colormap(self.vsec_array)
if self._colorbar:
self._cbar.remove()
self._update_colorbar()

@property
def colorbar(self):
return self._colorbar

@colorbar.setter
def colorbar(self, colorbar):
_validate_type(colorbar, bool, 'colorbar')
self._colorbar = colorbar
if self._colorbar:
# Remove old colorbar if already exists
if self._cbar is not None:
self._cbar.remove()
self._update_colorbar()
else:
self._cbar.remove()
self._cbar = None

0 comments on commit 9174280

Please sign in to comment.