Skip to content

Commit

Permalink
fixed GUI hang, plotting sub-epoch monitor updates
Browse files Browse the repository at this point in the history
  • Loading branch information
AdamStone committed Jan 10, 2015
1 parent 8abee26 commit 952903a
Showing 1 changed file with 77 additions and 56 deletions.
133 changes: 77 additions & 56 deletions pylearn2/train_extensions/live_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,29 @@
Training extension for allowing querying of monitoring values while an
experiment executes.
"""
__authors__ = "Dustin Webb"
__authors__ = "Dustin Webb, Adam Stone"
__copyright__ = "Copyright 2010-2012, Universite de Montreal"
__credits__ = ["Dustin Webb"]
__credits__ = ["Dustin Webb, Adam Stone"]
__license__ = "3-clause BSD"
__maintainer__ = "LISA Lab"
__email__ = "pylearn-dev@googlegroups"

import copy
import logging
log = logging.getLogger('LiveMonitor')

try:
import zmq
zmq_available = True
except:
except Exception:
zmq_available = False

try:
from PySide import QtCore, QtGui

import sys
import matplotlib
import numpy as np
matplotlib.use('Qt4Agg')
matplotlib.rcParams['backend.qt4'] = 'PySide'

Expand All @@ -30,13 +33,13 @@
from matplotlib.figure import Figure

qt_available = True
except:
except Exception:
qt_available = False

try:
import matplotlib.pyplot as plt
pyplot_available = True
except:
except Exception:
pyplot_available = False

from functools import wraps
Expand Down Expand Up @@ -361,53 +364,48 @@ def update_channels(self, channel_list, start=-1, end=-1, step=1):
chan.time_record += rsp_chan.time_record
chan.val_record += rsp_chan.val_record

def follow_channels(self, channel_list):
def follow_channels(self, channel_list, use_qt=False):
"""
Tracks and plots a specified set of channels in real time.
Parameters
----------
channel_list : list
A list of the channels for which data has been requested.
use_qt : bool
Use a PySide GUI for plotting, if available.
"""
if not pyplot_available:
if use_qt:
if not qt_available:
log.warning(
'follow_channels called with use_qt=True, but PySide '
'is not available. Falling back on matplotlib ion().')
else:
# only create new qt app if running the first time in session
if not hasattr(self, 'gui'):
self.gui = LiveMonitorGUI(self, channel_list)

self.gui.channel_list = channel_list
self.gui.start()

elif not pyplot_available:
raise ImportError('pyplot needs to be installed for '
'this functionality.')
plt.clf()
plt.ion()
while True:
self.update_channels(channel_list)
else:
plt.clf()
for channel_name in self.channels:
plt.plot(
self.channels[channel_name].epoch_record,
self.channels[channel_name].val_record,
label=channel_name
)
plt.legend()
plt.ion()
plt.draw()

def follow_channels_qt(self, channel_list):
"""
Tracks and plots a specified set of channels in real time using
a PySide Qt GUI.
Parameters
----------
channel_list : list
A list of the channels for which data has been requested.
"""
if not qt_available:
raise ImportError('PySide needs to be installed for ' +
'this functionality')

# only create qt app if running the first time
if not hasattr(self, 'gui'):
self.gui = LiveMonitorGUI(self, channel_list)

self.gui.channel_list = channel_list
self.gui.start()
while True:
self.update_channels(channel_list)
plt.clf()
for channel_name in self.channels:
plt.plot(
self.channels[channel_name].epoch_record,
self.channels[channel_name].val_record,
label=channel_name
)
plt.legend()
plt.ion()
plt.draw()

if qt_available:
class LiveMonitorGUI(QtGui.QMainWindow):
Expand All @@ -428,35 +426,58 @@ def __init__(self, lm, channel_list):
super(LiveMonitorGUI, self).__init__()
self.lm = lm
self.channel_list = channel_list
self.updaterThread = UpdaterThread(lm, channel_list)
self.updaterThread.updated.connect(self.refresh)
self.initUI()

def initUI(self):
self.fig = Figure(figsize=(600, 600), dpi=72,
self.resize(300, 200)
self.fig = Figure(figsize=(300, 200), dpi=72,
facecolor=(1, 1, 1), edgecolor=(0, 0, 0))
self.ax = self.fig.add_subplot(111)
self.canvas = FigureCanvas(self.fig)
self.setCentralWidget(self.canvas)

def update(self):
self.lm.update_channels(self.channel_list)
def refresh(self):
self.ax.cla() # clear previous plot
for channel_name in self.channel_list:
self.ax.plot(
self.lm.channels[channel_name].epoch_record,
self.lm.channels[channel_name].val_record,
label=channel_name
)

X = epoch_record = self.lm.channels[channel_name].epoch_record
Y = val_record = self.lm.channels[channel_name].val_record

indices = np.nonzero(np.diff(epoch_record))[0] + 1
epoch_record_split = np.split(epoch_record, indices)
val_record_split = np.split(val_record, indices)

X = np.zeros(len(epoch_record))
Y = np.zeros(len(epoch_record))

for i, epoch in enumerate(epoch_record_split):

j = i*len(epoch_record_split[0])
X[j: j + len(epoch)] = (
1.*np.arange(len(epoch)) / len(epoch) + epoch[0])
Y[j: j + len(epoch)] = val_record_split[i]

self.ax.plot(X, Y, label=channel_name)

self.ax.legend()
self.canvas.draw()

def closeEvent(self, event):
self.updateTimer.stop()
event.accept()
self.updaterThread.start()

def start(self):
self.updateTimer = QtCore.QTimer(self)
self.updateTimer.timeout.connect(self.update)
self.updateTimer.start(10000)
self.show()
self.update()
self.updaterThread.start()
self.app.exec_()

class UpdaterThread(QtCore.QThread):
updated = QtCore.Signal()

def __init__(self, lm, channel_list):
super(UpdaterThread, self).__init__()
self.lm = lm
self.channel_list = channel_list

def run(self):
self.lm.update_channels(self.channel_list) # blocking
self.updated.emit()

0 comments on commit 952903a

Please sign in to comment.