Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MRG] NetworkPlot class to enable interactive visualizations and animations #649

Merged
merged 30 commits into from
May 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
a712ccb
Add cell position argument ot cell plots
ntolley May 19, 2023
4e09d90
Add position tests
ntolley May 19, 2023
d55a831
WIP
ntolley May 20, 2023
e7ab8ff
start notebook
ntolley May 22, 2023
d058453
Start networkplot class
ntolley May 22, 2023
3693b0d
Functioning NetworkPlot class
ntolley May 22, 2023
c88bcbb
Update demo code
ntolley May 22, 2023
7e37564
Fix cell plot test
ntolley May 22, 2023
983929c
First pass at export function, fix vmin and vmax
ntolley May 23, 2023
5821004
Better export func
ntolley Jul 4, 2023
70e8881
Type checks for input args
ntolley Jul 5, 2023
6c21952
Better docs
ntolley Jul 5, 2023
c079edd
Make time_idx accept np.int
ntolley Sep 22, 2023
2a5fedd
Add self.ax
ntolley Sep 22, 2023
326af60
Add more type checks and simulation conditions
ntolley Sep 24, 2023
3e37322
Make update_voltages private
ntolley Sep 25, 2023
709788d
Update example script
ntolley Sep 25, 2023
6b58aea
formatting
ntolley Nov 20, 2023
bb568cf
respond to reviews and add colorbar functionality
ntolley Nov 20, 2023
1a5697b
add to api.rst
ntolley Nov 20, 2023
68fc6c6
Fix test with smaller net
ntolley May 15, 2024
f6ccccc
Use mesh_shape in example notebook
ntolley May 15, 2024
dcd55aa
update plot_morphology docstring
ntolley May 15, 2024
dc6e2f0
refactor network_plotter tests
ntolley May 15, 2024
303838d
better test readability
ntolley May 15, 2024
7d49391
refactor init
ntolley May 15, 2024
3b80cef
test animation export file exists
ntolley May 15, 2024
71e1d9a
loop for init test
ntolley May 15, 2024
15d856e
update whats_new
ntolley May 22, 2024
e9a94ef
Apply suggestions from code review
ntolley May 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ Visualization (:py:mod:`hnn_core.viz`):
plot_connectivity_matrix
plot_laminar_lfp
plot_laminar_csd
NetworkPlotter

Parallel backends (:py:mod:`hnn_core.parallel_backends`):
---------------------------------------------------------
Expand Down
3 changes: 3 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ Changelog
- Added feature to read/write :class:`~hnn_core.Network` configurations to
json, by `George Dang`_ and `Rajat Partani`_ in :gh:`757`

- Added :class:`~hnn_core/viz/NetworkPlotter` to visualize and animate network simulations,
by `Nick Tolley`_ in :gh:`649`.

Bug
~~~
- Fix inconsistent connection mapping from drive gids to cell gids, by
Expand Down
64 changes: 64 additions & 0 deletions examples/howto/plot_hnn_animation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""
================================
06. Animating HNN simulations
================================

This example demonstrates how to animate HNN simulations
"""

# Author: Nick Tolley <[email protected]>


###############################################################################
# First, we'll import the necessary modules for instantiating a network and
# running a simulation that we would like to animate.
import os.path as op

import hnn_core
from hnn_core import jones_2009_model, simulate_dipole, read_params
from hnn_core.network_models import add_erp_drives_to_jones_model

###############################################################################
# We begin by instantiating the network. For this example, we will reduce the
# number of cells in the network to speed up the simulations.
net = jones_2009_model(mesh_shape=(3, 3))

# Note that we move the cells further apart to allow better visualization of
# the network (default inplane_distance=1.0 µm).
net.set_cell_positions(inplane_distance=300)

###############################################################################
# The :class:`hnn_core.viz.NetworkPlotter` class can be used to visualize
# the 3D structure of the network.
from hnn_core.viz import NetworkPlotter

net_plot = NetworkPlotter(net)
net_plot.fig
jasmainak marked this conversation as resolved.
Show resolved Hide resolved

###############################################################################
# We can also visualize the network from another angle by adjusting the
# azimuth and elevation parameters.
net_plot.azim = 45
net_plot.elev = 40
net_plot.fig

###############################################################################
# Next we add event related potential (ERP) producing drives to the network
# and run the simulation (see
# :ref:`evoked example <sphx_glr_auto_examples_plot_simulate_evoked.py>`
# for more details).
# To visualize the membrane potential of cells in the
# network, we need use `simulate_dipole(..., record_vsec='all')` which turns
# on the recording of voltages in all sections of all cells in the network.
add_erp_drives_to_jones_model(net)
ntolley marked this conversation as resolved.
Show resolved Hide resolved
dpl = simulate_dipole(net, tstop=170, record_vsec='all')
net_plot = NetworkPlotter(net) # Reinitialize plotter with simulated network

###############################################################################
# Finally, we can animate the simulation using the `export_movie()` method. We
# can adjust the xyz limits of the plot to better visualize the network.
net_plot.xlim = (400, 1600)
net_plot.ylim = (400, 1600)
net_plot.zlim = (-500, 1600)
net_plot.azim = 225
net_plot.export_movie('animation_demo.gif', dpi=100, fps=30, interval=100)
jasmainak marked this conversation as resolved.
Show resolved Hide resolved
18 changes: 15 additions & 3 deletions hnn_core/cell.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,7 +863,9 @@ def parconnect_from_src(self, gid_presyn, nc_dict, postsyn,

return nc

def plot_morphology(self, ax=None, color=None, show=True):
def plot_morphology(self, ax=None, color=None, pos=(0, 0, 0),
xlim=(-250, 150), ylim=(-100, 100), zlim=(-100, 1200),
show=True):
ntolley marked this conversation as resolved.
Show resolved Hide resolved
"""Plot the cell morphology.

Parameters
Expand All @@ -875,8 +877,17 @@ def plot_morphology(self, ax=None, color=None, show=True):
color indicated by str. If dict, colors of individual sections
can be specified. Must have a key for every section in cell as
defined in the `Cell.sections` attribute.
| Ex: ``{'apical_trunk': 'r', 'soma': 'b', ...}``

| Ex: ``{'apical_trunk': 'r', 'soma': 'b', ...}``
pos : tuple of int or float | None
Position of cell soma. Must be a tuple of 3 elements for the
(x, y, z) position of the soma in 3D space. Default: (0, 0, 0)
xlim : tuple of int | tuple of float
x limits of plot window. Default (-250, 150)
ylim : tuple of int | tuple of float
y limits of plot window. Default (-100, 100)
zlim : tuple of int | tuple of float
z limits of plot window. Default (-100, 1200)
show : bool
If True, show the plot

Expand All @@ -885,7 +896,8 @@ def plot_morphology(self, ax=None, color=None, show=True):
axes : instance of Axes3D
The matplotlib 3D axis handle.
"""
return plot_cell_morphology(self, ax=ax, color=color, show=show)
return plot_cell_morphology(self, ax=ax, color=color, pos=pos,
xlim=xlim, ylim=ylim, zlim=zlim, show=show)

def _update_section_end_pts_L(self, node, dpt):
if self.cell_tree is None:
Expand Down
143 changes: 130 additions & 13 deletions hnn_core/tests/test_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,32 @@
import matplotlib
from matplotlib import backend_bases
import matplotlib.pyplot as plt
from matplotlib.colorbar import Colorbar

import numpy as np
from numpy.testing import assert_allclose
import pytest

import hnn_core
from hnn_core import read_params, jones_2009_model
from hnn_core.viz import plot_cells, plot_dipole, plot_psd, plot_tfr_morlet
from hnn_core.viz import plot_connectivity_matrix, plot_cell_connectivity
from hnn_core.viz import (plot_cells, plot_dipole, plot_psd, plot_tfr_morlet,
plot_connectivity_matrix, plot_cell_connectivity,
NetworkPlotter)
from hnn_core.dipole import simulate_dipole

matplotlib.use('agg')


@pytest.fixture
def setup_net():
hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
net = jones_2009_model(params, mesh_shape=(3, 3))

return net


def _fake_click(fig, ax, point, button=1):
"""Fake a click at a point within axes."""
x, y = ax.transData.transform_point(point)
Expand All @@ -26,12 +39,9 @@ def _fake_click(fig, ax, point, button=1):
fig.canvas.callbacks.process('button_press_event', button_press_event)


def test_network_visualization():
def test_network_visualization(setup_net):
"""Test network visualisations."""
hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
net = jones_2009_model(params, mesh_shape=(3, 3))
net = setup_net
plot_cells(net)
ax = net.cell_types['L2_pyramidal'].plot_morphology()
assert len(ax.lines) == 8
Expand Down Expand Up @@ -84,6 +94,13 @@ def test_network_visualization():
with pytest.raises(TypeError,
match="'ax' to be an instance of Axes3D, but got Axes"):
plot_cells(net, ax=axes, show=False)
cell_type.plot_morphology(pos=(1.0, 2.0, 3.0))
with pytest.raises(TypeError, match='pos must be'):
cell_type.plot_morphology(pos=123)
with pytest.raises(ValueError, match='pos must be a tuple of 3 elements'):
cell_type.plot_morphology(pos=(1, 2, 3, 4))
with pytest.raises(TypeError, match='pos\\[idx\\] must be'):
cell_type.plot_morphology(pos=(1, '2', 3))

plt.close('all')

Expand All @@ -104,12 +121,9 @@ def test_network_visualization():
plt.close('all')


def test_dipole_visualization():
def test_dipole_visualization(setup_net):
"""Test dipole visualisations."""
hnn_core_root = op.dirname(hnn_core.__file__)
params_fname = op.join(hnn_core_root, 'param', 'default.json')
params = read_params(params_fname)
net = jones_2009_model(params, mesh_shape=(3, 3))
net = setup_net
weights_ampa = {'L2_pyramidal': 5.4e-5, 'L5_pyramidal': 5.4e-5}
syn_delays = {'L2_pyramidal': 0.1, 'L5_pyramidal': 1.}

Expand All @@ -125,7 +139,7 @@ def test_dipole_visualization():
weights_ampa=weights_ampa, synaptic_delays=syn_delays,
event_seed=14)

dpls = simulate_dipole(net, tstop=100., n_trials=2)
dpls = simulate_dipole(net, tstop=100., n_trials=2, record_vsec='all')
fig = dpls[0].plot() # plot the first dipole alone
axes = fig.get_axes()[0]
dpls[0].copy().smooth(window_len=10).plot(ax=axes) # add smoothed versions
Expand Down Expand Up @@ -212,5 +226,108 @@ def test_dipole_visualization():
'beta_dist': 'g'})
with pytest.raises(ValueError, match="'beta_dist' must be"):
net.cell_response.plot_spikes_hist(color={'beta_prox': 'r'})
plt.close('all')


def test_network_plotter_init(setup_net):
"""Test init keywords of NetworkPlotter class."""
net = setup_net
# test NetworkPlotter class
jasmainak marked this conversation as resolved.
Show resolved Hide resolved
args = ['xlim', 'ylim', 'zlim', 'elev', 'azim', 'vmin', 'vmax',
'trial_idx', 'time_idx', 'colorbar']
for arg in args:
with pytest.raises(TypeError, match=f'{arg} must be'):
net_plot = NetworkPlotter(net, **{arg: 'blah'})
jasmainak marked this conversation as resolved.
Show resolved Hide resolved

net_plot = NetworkPlotter(net)

assert net_plot.vsec_array.shape == (159, 1)
assert net_plot.color_array.shape == (159, 1, 4)
assert net_plot._vsec_recorded is False
plt.close('all')


def test_network_plotter_simulation(setup_net):
"""Test NetworkPlotter class simulation warnings."""
net = setup_net
net_plot = NetworkPlotter(net)
# Errors if vsec isn't recorded
with pytest.raises(RuntimeError, match='Network must be simulated'):
net_plot.export_movie('demo.gif', dpi=200)

# Errors if vsec isn't recorded with record_vsec='all'
_ = simulate_dipole(net, dt=0.5, tstop=10, record_vsec='soma')
net_plot = NetworkPlotter(net)

assert net_plot.vsec_array.shape == (159, 1)
assert net_plot.color_array.shape == (159, 1, 4)
assert net_plot._vsec_recorded is False

with pytest.raises(RuntimeError, match='Network must be simulated'):
net_plot.export_movie('demo.gif', dpi=200)

net = setup_net
_ = simulate_dipole(net, dt=0.5, tstop=10, record_vsec='all', n_trials=2)
net_plot = NetworkPlotter(net)
# setter/getter test for time_idx and trial_idx
net_plot.time_idx = 5
assert net_plot.time_idx == 5
net_plot.trial_idx = 1
assert net_plot.trial_idx == 1

assert net_plot.vsec_array.shape == (159, 21)
assert net_plot.color_array.shape == (159, 21, 4)
assert net_plot._vsec_recorded is True
assert isinstance(net_plot._cbar, Colorbar)
plt.close('all')


def test_network_plotter_setter(setup_net):
"""Test NetworkPlotter class setters and getters."""
net = setup_net
net_plot = NetworkPlotter(net)
# Type check errors
args = ['xlim', 'ylim', 'zlim', 'elev', 'azim', 'vmin', 'vmax',
'trial_idx', 'time_idx', 'colorbar']
for arg in args:
with pytest.raises(TypeError, match=f'{arg} must be'):
setattr(net_plot, arg, 'blah')

# Check that the setters and getters work
arg_dict = {'xlim': (-100, 100), 'ylim': (-100, 100), 'zlim': (-100, 100),
'elev': 10, 'azim': 10, 'vmin': 0, 'vmax': 100,
'bgcolor': 'white', 'voltage_colormap': 'jet',
'colorbar': False}
for arg, val in arg_dict.items():
setattr(net_plot, arg, val)
assert getattr(net_plot, arg) == val

assert net_plot._cbar is None
assert net_plot.fig.get_facecolor() == (1.0, 1.0, 1.0, 1.0)

# time_idx setter should raise an error if network is not simulated
with pytest.raises(RuntimeError, match='Network must be simulated'):
net_plot.time_idx = 5

with pytest.raises(RuntimeError, match='Network must be simulated'):
net_plot.trial_idx = 1
plt.close('all')


def test_network_plotter_export(tmp_path, setup_net):
"""Test NetworkPlotter class export methods."""
net = setup_net
_ = simulate_dipole(net, dt=0.5, tstop=10, n_trials=1,
record_vsec='all')
net_plot = NetworkPlotter(net)

# Check no file is already written
path_out = tmp_path / 'demo.gif'
assert not path_out.is_file()

# Test animation export and voltage plotting
net_plot.export_movie(path_out, dpi=200, decim=100, writer='pillow')

assert path_out.is_file()

plt.close('all')
Loading
Loading