Skip to content

Commit

Permalink
[MRG] Fix GUI dipole data overlay (#869)
Browse files Browse the repository at this point in the history
* test: added test_dipole_data_overlay to test a figure with both simulated data and loaded data

* fix: updated _simulate_edit_figure to reflect current figure widgets

* fix: added a function to check for existing averaged dipoles

-The GUI can average and append the average to the dipole list at an earlier stage than plotting with "data to compare" specified. This solution checks for an existing average, if not it will return an average.
-This fixes an issue where the average_dipoles function was throwing an error because it was passed a list with an averaged dipole already.

* style: removed white space in docstring

* feat: added check for empty list in _avg_dipole_check

* docs: updated whats_new.rst
  • Loading branch information
gtdang authored Aug 27, 2024
1 parent 18830b5 commit b384a10
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 6 deletions.
2 changes: 2 additions & 0 deletions doc/whats_new.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ Changelog

Bug
~~~
- Fix GUI over-plotting of loaded data where the app stalled and did not plot
RMSE, by `George Dang`_ in :gh:`869`

API
~~~
Expand Down
29 changes: 24 additions & 5 deletions hnn_core/gui/_viz_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,20 @@ def _dynamic_rerender(fig):
fig.tight_layout()


def _avg_dipole_check(dpls):
"""Check for averaged dipole, else average the trials"""
# Check if there is an averaged dipole already
if not dpls:
return None

avg_dpls = [d for d in dpls if d.nave > 1]
if avg_dpls:
dpl = avg_dpls[0]
else:
dpl = average_dipoles(dpls)
return dpl


def _plot_on_axes(b, simulations_widget, widgets_plot_type,
data_widget,
spectrogram_colormap_selection, max_spectral_frequency,
Expand Down Expand Up @@ -427,7 +441,7 @@ def _plot_on_axes(b, simulations_widget, widgets_plot_type,
t0 = 0.0
tstop = dpls_processed[-1].times[-1]
if len(dpls_processed) > 1:
dpl = average_dipoles(dpls_processed)
dpl = _avg_dipole_check(dpls_processed)
else:
dpl = dpls_processed
rmse = _rmse(dpl, target_dpl_processed, t0, tstop)
Expand Down Expand Up @@ -983,8 +997,10 @@ def _simulate_edit_figure(self, fig_name, ax_name, simulation_name,
Type of visualization.
preprocessing_config : dict
A dict of visualization preprocessing parameters. Allowed keys:
`dipole_smooth`, `dipole_scaling`, `max_spectral_frequency`,
`spectrogram_colormap_selection`. config could be empty: `{}`.
`dipole_smooth`, `dipole_scaling`,
`data_to_compare`, `data_smooth`, `data_scaling`
`max_spectral_frequency`, `spectrogram_colormap_selection`.
config could be empty: `{}`.
operation : str
`"plot"` if you want to plot and `"clear"` if you want to
remove previously plotted visualizations.
Expand Down Expand Up @@ -1016,8 +1032,11 @@ def _simulate_edit_figure(self, fig_name, ax_name, simulation_name,
config_name_idx = {
"dipole_smooth": 2,
"dipole_scaling": 3,
"max_spectral_frequency": 4,
"spectrogram_colormap_selection": 5,
"data_to_compare": 4,
"data_smooth": 5,
"data_scaling": 6,
"max_spectral_frequency": 7,
"spectrogram_colormap_selection": 8,
}
for conf_key, conf_val in preprocessing_config.items():
assert conf_key in config_name_idx.keys()
Expand Down
40 changes: 39 additions & 1 deletion hnn_core/tests/test_gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -678,7 +678,7 @@ def test_gui_adaptive_spectrogram(setup_gui):


def test_gui_visualization(setup_gui):
""" Tests updating a figure creates plots with data. """
"""Tests updating a figure creates plots with data."""

gui = setup_gui
gui.run_button.click()
Expand Down Expand Up @@ -712,6 +712,44 @@ def test_gui_visualization(setup_gui):
plt.close('all')


def test_dipole_data_overlay(setup_gui):
"""Tests dipole plot with a simulation and data overlay."""
gui = setup_gui

# Run simulation with 2 trials
gui.widget_ntrials.value = 2
gui.run_button.click()

# Load data
file_path = assets_path / 'test_default.csv'
gui._simulate_upload_data(file_path)

# Edit the figure with data overlay
figid = 1
figname = f'Figure {figid}'
axname = 'ax1'
gui._simulate_viz_action("edit_figure", figname,
axname, 'default', 'current dipole', {}, 'clear')
gui._simulate_viz_action("edit_figure", figname,
axname, 'default', 'current dipole',
{'data_to_compare': 'test_default'},
'plot')
ax = gui.viz_manager.figs[figid].axes[1]

# Check number of lines
# 2 trials, 1 average, 2 data (data is over-plotted twice for some reason)
# But it only appears in the legend once.
assert len(ax.lines) == 5
assert len(ax.legend_.texts) == 2
assert ax.legend_.texts[0]._text == 'default: average'
assert ax.legend_.texts[1]._text == 'test_default'

# Check RMSE is printed
assert 'RMSE(default, test_default):' in ax.texts[0]._text

plt.close('all')


def test_unlink_relink_widget():
"""Tests the unlinking and relinking of widgets decorator."""

Expand Down

0 comments on commit b384a10

Please sign in to comment.