Skip to content

Commit

Permalink
adding support for spike_times data in gui (#27)
Browse files Browse the repository at this point in the history
  • Loading branch information
carsen-stringer committed Oct 18, 2024
1 parent aa5fe2d commit aca7ec9
Show file tree
Hide file tree
Showing 7 changed files with 133 additions and 15 deletions.
12 changes: 12 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,18 @@ Most of the time you will input to `Rastermap().fit` a matrix of neurons by time
* **itrain** : array, shape (n_features,) (optional, default None)
fit embedding on timepoints itrain only

If you have a `spike_times.npy` and `spike_clusters.npy`, create your time-binned data
matrix with, where the bin size `st_bin` is in milliseconds (assuming your spike times are in seconds):

```
from rastermap import io
# bin spike times into neurons by time matrix
data = io.load_spike_times("spike_times.npy", "spike_clusters.npy", st_bin=100)
```

You can also load these matrices into the GUI with the `File > Load spike_times...` option.

# Settings

These are inputs to the `Rastermap` class initialization, the settings are sorted in order of importance
Expand Down
22 changes: 16 additions & 6 deletions rastermap/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import argparse
import os
from rastermap import Rastermap
from rastermap.io import load_activity
from rastermap.io import load_activity, load_spike_times

try:
from rastermap.gui import gui
Expand All @@ -24,15 +24,23 @@
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="spikes")
parser.add_argument("--S", default=[], type=str, help="spiking matrix")
parser.add_argument("--spike_times", default=[], type=str, help="spike_times.npy")
parser.add_argument("--spike_clusters", default=[], type=str, help="spike_clusters.npy")
parser.add_argument("--st_bin", default=100, type=float, help="bin size in milliseconds for spike times")
parser.add_argument("--proc", default=[], type=str,
help="processed data file 'embedding.npy'")
parser.add_argument("--ops", default=[], type=str, help="options file 'ops.npy'")
parser.add_argument("--iscell", default=[], type=str,
help="which cells to select for processing")
args = parser.parse_args()

if len(args.ops) > 0 and len(args.S) > 0:
X, Usv, Vsv, xy = load_activity(args.S)
if len(args.ops) > 0 and (len(args.S) > 0 or
(len(args.spike_times) > 0 and len(args.spike_clusters) > 0)):
if len(args.S) > 0:
X, Usv, Vsv, xy = load_activity(args.S)
else:
Usv, Vsv, xy = None, None, None
X = load_spike_times(args.spike_times, args.spike_clusters, args.st_bin)
ops = np.load(args.ops, allow_pickle=True).item()
if len(args.iscell) > 0:
iscell = np.load(args.iscell)
Expand Down Expand Up @@ -62,14 +70,16 @@
model.fit(data=X, Usv=Usv, Vsv=Vsv)

proc = {
"filename": args.S,
"save_path": os.path.split(args.S)[0],
"filename": args.S if len(args.S) > 0 else args.spike_times,
"filename_cluid": args.spike_clusters if args.spike_clusters else None,
"st_bin": args.st_bin if args.spike_clusters else None,
"save_path": os.path.split(args.S)[0] if args.S else os.path.split(args.spike_times)[0],
"isort": model.isort,
"embedding": model.embedding,
"user_clusters": None,
"ops": ops,
}
basename, fname = os.path.split(args.S)
basename, fname = os.path.split(args.S) if args.S else os.path.split(args.spike_times)
fname = os.path.splitext(fname)[0]
try:
np.save(os.path.join(basename, f"{fname}_embedding.npy"), proc)
Expand Down
1 change: 1 addition & 0 deletions rastermap/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,7 @@ def __init__(self, filename=None, proc=False):
# Default variables
self.tpos = -0.5
self.tsize = 1
self.from_spike_times = False
self.reset_variables()

self.init_time_roi()
Expand Down
83 changes: 79 additions & 4 deletions rastermap/gui/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import os
import numpy as np
from qtpy import QtGui, QtCore, QtWidgets
from qtpy.QtWidgets import QFileDialog, QInputDialog, QMainWindow, QApplication, QWidget, QScrollBar, QSlider, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox
from qtpy.QtWidgets import QFileDialog, QDialog, QInputDialog, QMainWindow, QApplication, QWidget, QScrollBar, QSlider, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox
import pyqtgraph as pg
from scipy.stats import zscore
from scipy.sparse import csr_array
import scipy.io as sio
from . import guiparts
from ..io import _load_iscell, _load_stat, load_activity
from ..io import _load_iscell, _load_stat, load_activity, load_spike_times

def _load_activity_gui(parent, X, Usv, Vsv, xy):
parent.reset_variables()
Expand Down Expand Up @@ -54,6 +55,66 @@ def _load_activity_gui(parent, X, Usv, Vsv, xy):
parent.sorting = np.arange(0, parent.n_samples).astype(np.int64)
_load_sp(parent)


class SpikeTimeLoad(QDialog):
def __init__(self, parent=None):
super().__init__()
self.parent = parent
self.layout = QGridLayout(self)
self.files, self.file_labels = [], []
lbl = ["spike_times", "spike_clusters"]
for j in range(2):
self.files.append(QPushButton(f"Choose {lbl[j]} file", self))
self.files[-1].clicked.connect(lambda state, idx=j: self.get_file(idx))
self.layout.addWidget(self.files[-1], j, 0, 1, 1)
self.file_labels.append(QLabel("", self))
self.layout.addWidget(self.file_labels[-1], j, 1, 1, 1)

# time bin size lineedit
self.time_bin = QLineEdit("100", self)
self.layout.addWidget(self.time_bin, 2, 1, 1, 1)
self.layout.addWidget(QLabel("Time bin size (in millisec.)", self), 2, 0, 1, 1)

# Submit button
self.submit_button = QPushButton("Submit", self)
self.submit_button.clicked.connect(self.submit)
self.layout.addWidget(self.submit_button, 3, 0, 1, 2)

self.setWindowTitle("spike time input")
self.show()

def get_file(self, j):
name = QFileDialog.getOpenFileName(self, "Open *.npy or *.mat",
filter="*.npy *.mat")
self.file_labels[j].setText(name[0])

def submit(self):
if not self.file_labels[0].text() or not self.file_labels[1].text():
QMessageBox.critical(self, 'Error', 'Both files are required!')
else:
fname = self.file_labels[0].text()
fname_cluid = self.file_labels[1].text()
st_bin = float(self.time_bin.text())
st_bin = min(5000, max(5, st_bin))
print(f"setting bin size to {st_bin} for visualization")
_load_spike_times(self.parent, fname, fname_cluid, st_bin)

def _load_spike_times(parent, fname, fname_cluid, st_bin):
spks = load_spike_times(fname, fname_cluid, st_bin)
parent.fname = fname
parent.fname_cluid = fname_cluid
parent.st_bin = st_bin
parent.from_spike_times = True
_load_activity_gui(parent, spks, None, None, None)

def load_st_clu(parent, name=None):
""" load spike times of neurons (*.npy or *.mat) """
if name is None:
st = SpikeTimeLoad(parent)
st.exec_()



def load_mat(parent, name=None):
""" load data matrix of neurons by time (*.npy or *.mat)
Expand Down Expand Up @@ -361,13 +422,27 @@ def load_proc(parent, name=None):
else:
print(f"ERROR: {parent.proc['filename']} not found")
return

if parent.proc["filename_cluid"]:
if os.path.exists(parent.proc["filename_cluid"]):
parent.fname_cluid = parent.proc["filename_cluid"]
elif os.path.exists(os.path.join(foldername, filename)):
parent.fname_cluid = os.path.join(foldername, filename)
else:
print(f"ERROR: {parent.proc['filename_cluid']} not found")
return
parent.st_bin = parent.proc["st_bin"]


isort = parent.proc["isort"]
y = parent.proc["embedding"]
ops = parent.proc["ops"]
user_clusters = parent.proc.get("user_clusters", None)

X, Usv, Vsv, xy = load_activity(parent.fname)
if parent.proc["filename_cluid"]:
Usv, Vsv, xy = None, None, None
X = load_spike_times(parent.fname, parent.fname_cluid, parent.st_bin)
else:
X, Usv, Vsv, xy = load_activity(parent.fname)
_load_activity_gui(parent, X, Usv, Vsv, xy)

except Exception as e:
Expand Down
14 changes: 10 additions & 4 deletions rastermap/gui/menus.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,26 +15,32 @@ def mainmenu(parent):

file_menu = main_menu.addMenu("&File")

loadMat = QAction("&Load data matrix", parent)
loadMat = QAction("&Load data matrix (neurons by time)", parent)
loadMat.setShortcut("Ctrl+L")
loadMat.triggered.connect(lambda: io.load_mat(parent, name=None))
parent.addAction(loadMat)
file_menu.addAction(loadMat)

parent.loadXY = QAction("&Load xy(z) positions of neurons", parent)
loadSt = QAction("Load spike_times and spike_&Clusters", parent)
loadSt.setShortcut("Ctrl+C")
loadSt.triggered.connect(lambda: io.load_st_clu(parent, name=None))
parent.addAction(loadSt)
file_menu.addAction(loadSt)

parent.loadXY = QAction("Load &XY(z) positions of neurons", parent)
parent.loadXY.setShortcut("Ctrl+X")
parent.loadXY.triggered.connect(lambda: io.load_neuron_pos(parent))
parent.addAction(parent.loadXY)
file_menu.addAction(parent.loadXY)

# load Z-stack
parent.loadProc = QAction("&Load z-stack (mean images)", parent)
parent.loadProc = QAction("Load &Z-stack (mean images)", parent)
parent.loadProc.setShortcut("Ctrl+Z")
parent.loadProc.triggered.connect(lambda: io.load_zstack(parent, name=None))
parent.addAction(parent.loadProc)
file_menu.addAction(parent.loadProc)

parent.loadNd = QAction("Load &n-d variable (times or cont.)", parent)
parent.loadNd = QAction("Load &N-d variable (times or cont.)", parent)
parent.loadNd.setShortcut("Ctrl+N")
parent.loadNd.triggered.connect(lambda: io.get_behav_data(parent))
parent.loadNd.setEnabled(False)
Expand Down
5 changes: 4 additions & 1 deletion rastermap/gui/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,10 @@ def run_RMAP(self, parent):
ops_path = os.path.join(os.getcwd(), "rmap_ops.npy")
np.save(ops_path, self.ops)
print("Running rastermap with command:")
cmd = f"-u -W ignore -m rastermap --ops {ops_path} --S {parent.fname}"
if parent.from_spike_times:
cmd = f"-u -W ignore -m rastermap --ops {ops_path} --spike_times {parent.fname} --spike_clusters {parent.fname_cluid} --st_bin {parent.st_bin}"
else:
cmd = f"-u -W ignore -m rastermap --ops {ops_path} --S {parent.fname}"
if parent.file_iscell is not None:
cmd += f" --iscell {parent.file_iscell}"
print("python " + cmd)
Expand Down
11 changes: 11 additions & 0 deletions rastermap/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import numpy as np
import scipy.io as sio
from scipy.stats import zscore
from scipy.sparse import csr_array

def _load_dict(dat, keys):
X, Usv, Vsv, xpos, ypos, xy = None, None, None, None, None, None
Expand Down Expand Up @@ -156,6 +157,16 @@ def load_activity(filename):

return X, Usv, Vsv, xy

def load_spike_times(fname, fname_cluid, st_bin=100):
print("Loading " + fname)
st = np.load(fname).squeeze()
clu = np.load(fname_cluid).squeeze()
if len(st) != len(clu):
raise ValueError("spike times and clusters must have same length")
spks = csr_array((np.ones(len(st), "uint8"),
(clu, np.floor(st / st_bin * 1000).astype("int"))))
spks = spks.todense().astype("float32")
return spks

def _cell_center(voxel_mask):
x = np.median(np.array([v[0] for v in voxel_mask]))
Expand Down

0 comments on commit aca7ec9

Please sign in to comment.