diff --git a/README.md b/README.md index 61c7811..9d95865 100644 --- a/README.md +++ b/README.md @@ -198,6 +198,18 @@ Most of the time you will input to `Rastermap().fit` a matrix of neurons by time * **itrain** : array, shape (n_features,) (optional, default None) fit embedding on timepoints itrain only +If you have a `spike_times.npy` and `spike_clusters.npy`, create your time-binned data +matrix with, where the bin size `st_bin` is in milliseconds (assuming your spike times are in seconds): + +``` +from rastermap import io + +# bin spike times into neurons by time matrix +data = io.load_spike_times("spike_times.npy", "spike_clusters.npy", st_bin=100) +``` + +You can also load these matrices into the GUI with the `File > Load spike_times...` option. + # Settings These are inputs to the `Rastermap` class initialization, the settings are sorted in order of importance diff --git a/rastermap/__main__.py b/rastermap/__main__.py index b105704..9eaead7 100644 --- a/rastermap/__main__.py +++ b/rastermap/__main__.py @@ -6,7 +6,7 @@ import argparse import os from rastermap import Rastermap -from rastermap.io import load_activity +from rastermap.io import load_activity, load_spike_times try: from rastermap.gui import gui @@ -24,6 +24,9 @@ if __name__ == "__main__": parser = argparse.ArgumentParser(description="spikes") parser.add_argument("--S", default=[], type=str, help="spiking matrix") + parser.add_argument("--spike_times", default=[], type=str, help="spike_times.npy") + parser.add_argument("--spike_clusters", default=[], type=str, help="spike_clusters.npy") + parser.add_argument("--st_bin", default=100, type=float, help="bin size in milliseconds for spike times") parser.add_argument("--proc", default=[], type=str, help="processed data file 'embedding.npy'") parser.add_argument("--ops", default=[], type=str, help="options file 'ops.npy'") @@ -31,8 +34,13 @@ help="which cells to select for processing") args = parser.parse_args() - if len(args.ops) > 0 and len(args.S) > 0: - X, Usv, Vsv, xy = load_activity(args.S) + if len(args.ops) > 0 and (len(args.S) > 0 or + (len(args.spike_times) > 0 and len(args.spike_clusters) > 0)): + if len(args.S) > 0: + X, Usv, Vsv, xy = load_activity(args.S) + else: + Usv, Vsv, xy = None, None, None + X = load_spike_times(args.spike_times, args.spike_clusters, args.st_bin) ops = np.load(args.ops, allow_pickle=True).item() if len(args.iscell) > 0: iscell = np.load(args.iscell) @@ -62,14 +70,16 @@ model.fit(data=X, Usv=Usv, Vsv=Vsv) proc = { - "filename": args.S, - "save_path": os.path.split(args.S)[0], + "filename": args.S if len(args.S) > 0 else args.spike_times, + "filename_cluid": args.spike_clusters if args.spike_clusters else None, + "st_bin": args.st_bin if args.spike_clusters else None, + "save_path": os.path.split(args.S)[0] if args.S else os.path.split(args.spike_times)[0], "isort": model.isort, "embedding": model.embedding, "user_clusters": None, "ops": ops, } - basename, fname = os.path.split(args.S) + basename, fname = os.path.split(args.S) if args.S else os.path.split(args.spike_times) fname = os.path.splitext(fname)[0] try: np.save(os.path.join(basename, f"{fname}_embedding.npy"), proc) diff --git a/rastermap/gui/gui.py b/rastermap/gui/gui.py index e01f39a..5c2c0e7 100644 --- a/rastermap/gui/gui.py +++ b/rastermap/gui/gui.py @@ -178,6 +178,7 @@ def __init__(self, filename=None, proc=False): # Default variables self.tpos = -0.5 self.tsize = 1 + self.from_spike_times = False self.reset_variables() self.init_time_roi() diff --git a/rastermap/gui/io.py b/rastermap/gui/io.py index 8d95c39..0759de6 100644 --- a/rastermap/gui/io.py +++ b/rastermap/gui/io.py @@ -4,12 +4,13 @@ import os import numpy as np from qtpy import QtGui, QtCore, QtWidgets -from qtpy.QtWidgets import QFileDialog, QInputDialog, QMainWindow, QApplication, QWidget, QScrollBar, QSlider, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox +from qtpy.QtWidgets import QFileDialog, QDialog, QInputDialog, QMainWindow, QApplication, QWidget, QScrollBar, QSlider, QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, QLineEdit, QMessageBox, QGroupBox import pyqtgraph as pg from scipy.stats import zscore +from scipy.sparse import csr_array import scipy.io as sio from . import guiparts -from ..io import _load_iscell, _load_stat, load_activity +from ..io import _load_iscell, _load_stat, load_activity, load_spike_times def _load_activity_gui(parent, X, Usv, Vsv, xy): parent.reset_variables() @@ -54,6 +55,66 @@ def _load_activity_gui(parent, X, Usv, Vsv, xy): parent.sorting = np.arange(0, parent.n_samples).astype(np.int64) _load_sp(parent) + +class SpikeTimeLoad(QDialog): + def __init__(self, parent=None): + super().__init__() + self.parent = parent + self.layout = QGridLayout(self) + self.files, self.file_labels = [], [] + lbl = ["spike_times", "spike_clusters"] + for j in range(2): + self.files.append(QPushButton(f"Choose {lbl[j]} file", self)) + self.files[-1].clicked.connect(lambda state, idx=j: self.get_file(idx)) + self.layout.addWidget(self.files[-1], j, 0, 1, 1) + self.file_labels.append(QLabel("", self)) + self.layout.addWidget(self.file_labels[-1], j, 1, 1, 1) + + # time bin size lineedit + self.time_bin = QLineEdit("100", self) + self.layout.addWidget(self.time_bin, 2, 1, 1, 1) + self.layout.addWidget(QLabel("Time bin size (in millisec.)", self), 2, 0, 1, 1) + + # Submit button + self.submit_button = QPushButton("Submit", self) + self.submit_button.clicked.connect(self.submit) + self.layout.addWidget(self.submit_button, 3, 0, 1, 2) + + self.setWindowTitle("spike time input") + self.show() + + def get_file(self, j): + name = QFileDialog.getOpenFileName(self, "Open *.npy or *.mat", + filter="*.npy *.mat") + self.file_labels[j].setText(name[0]) + + def submit(self): + if not self.file_labels[0].text() or not self.file_labels[1].text(): + QMessageBox.critical(self, 'Error', 'Both files are required!') + else: + fname = self.file_labels[0].text() + fname_cluid = self.file_labels[1].text() + st_bin = float(self.time_bin.text()) + st_bin = min(5000, max(5, st_bin)) + print(f"setting bin size to {st_bin} for visualization") + _load_spike_times(self.parent, fname, fname_cluid, st_bin) + +def _load_spike_times(parent, fname, fname_cluid, st_bin): + spks = load_spike_times(fname, fname_cluid, st_bin) + parent.fname = fname + parent.fname_cluid = fname_cluid + parent.st_bin = st_bin + parent.from_spike_times = True + _load_activity_gui(parent, spks, None, None, None) + +def load_st_clu(parent, name=None): + """ load spike times of neurons (*.npy or *.mat) """ + if name is None: + st = SpikeTimeLoad(parent) + st.exec_() + + + def load_mat(parent, name=None): """ load data matrix of neurons by time (*.npy or *.mat) @@ -361,13 +422,27 @@ def load_proc(parent, name=None): else: print(f"ERROR: {parent.proc['filename']} not found") return + + if parent.proc["filename_cluid"]: + if os.path.exists(parent.proc["filename_cluid"]): + parent.fname_cluid = parent.proc["filename_cluid"] + elif os.path.exists(os.path.join(foldername, filename)): + parent.fname_cluid = os.path.join(foldername, filename) + else: + print(f"ERROR: {parent.proc['filename_cluid']} not found") + return + parent.st_bin = parent.proc["st_bin"] + isort = parent.proc["isort"] y = parent.proc["embedding"] ops = parent.proc["ops"] user_clusters = parent.proc.get("user_clusters", None) - - X, Usv, Vsv, xy = load_activity(parent.fname) + if parent.proc["filename_cluid"]: + Usv, Vsv, xy = None, None, None + X = load_spike_times(parent.fname, parent.fname_cluid, parent.st_bin) + else: + X, Usv, Vsv, xy = load_activity(parent.fname) _load_activity_gui(parent, X, Usv, Vsv, xy) except Exception as e: diff --git a/rastermap/gui/menus.py b/rastermap/gui/menus.py index ed2aa84..f4af53e 100644 --- a/rastermap/gui/menus.py +++ b/rastermap/gui/menus.py @@ -15,26 +15,32 @@ def mainmenu(parent): file_menu = main_menu.addMenu("&File") - loadMat = QAction("&Load data matrix", parent) + loadMat = QAction("&Load data matrix (neurons by time)", parent) loadMat.setShortcut("Ctrl+L") loadMat.triggered.connect(lambda: io.load_mat(parent, name=None)) parent.addAction(loadMat) file_menu.addAction(loadMat) - parent.loadXY = QAction("&Load xy(z) positions of neurons", parent) + loadSt = QAction("Load spike_times and spike_&Clusters", parent) + loadSt.setShortcut("Ctrl+C") + loadSt.triggered.connect(lambda: io.load_st_clu(parent, name=None)) + parent.addAction(loadSt) + file_menu.addAction(loadSt) + + parent.loadXY = QAction("Load &XY(z) positions of neurons", parent) parent.loadXY.setShortcut("Ctrl+X") parent.loadXY.triggered.connect(lambda: io.load_neuron_pos(parent)) parent.addAction(parent.loadXY) file_menu.addAction(parent.loadXY) # load Z-stack - parent.loadProc = QAction("&Load z-stack (mean images)", parent) + parent.loadProc = QAction("Load &Z-stack (mean images)", parent) parent.loadProc.setShortcut("Ctrl+Z") parent.loadProc.triggered.connect(lambda: io.load_zstack(parent, name=None)) parent.addAction(parent.loadProc) file_menu.addAction(parent.loadProc) - parent.loadNd = QAction("Load &n-d variable (times or cont.)", parent) + parent.loadNd = QAction("Load &N-d variable (times or cont.)", parent) parent.loadNd.setShortcut("Ctrl+N") parent.loadNd.triggered.connect(lambda: io.get_behav_data(parent)) parent.loadNd.setEnabled(False) diff --git a/rastermap/gui/run.py b/rastermap/gui/run.py index 43b6da3..f230c92 100644 --- a/rastermap/gui/run.py +++ b/rastermap/gui/run.py @@ -82,7 +82,10 @@ def run_RMAP(self, parent): ops_path = os.path.join(os.getcwd(), "rmap_ops.npy") np.save(ops_path, self.ops) print("Running rastermap with command:") - cmd = f"-u -W ignore -m rastermap --ops {ops_path} --S {parent.fname}" + if parent.from_spike_times: + cmd = f"-u -W ignore -m rastermap --ops {ops_path} --spike_times {parent.fname} --spike_clusters {parent.fname_cluid} --st_bin {parent.st_bin}" + else: + cmd = f"-u -W ignore -m rastermap --ops {ops_path} --S {parent.fname}" if parent.file_iscell is not None: cmd += f" --iscell {parent.file_iscell}" print("python " + cmd) diff --git a/rastermap/io.py b/rastermap/io.py index 0b3cb25..d01f0ed 100644 --- a/rastermap/io.py +++ b/rastermap/io.py @@ -5,6 +5,7 @@ import numpy as np import scipy.io as sio from scipy.stats import zscore +from scipy.sparse import csr_array def _load_dict(dat, keys): X, Usv, Vsv, xpos, ypos, xy = None, None, None, None, None, None @@ -156,6 +157,16 @@ def load_activity(filename): return X, Usv, Vsv, xy +def load_spike_times(fname, fname_cluid, st_bin=100): + print("Loading " + fname) + st = np.load(fname).squeeze() + clu = np.load(fname_cluid).squeeze() + if len(st) != len(clu): + raise ValueError("spike times and clusters must have same length") + spks = csr_array((np.ones(len(st), "uint8"), + (clu, np.floor(st / st_bin * 1000).astype("int")))) + spks = spks.todense().astype("float32") + return spks def _cell_center(voxel_mask): x = np.median(np.array([v[0] for v in voxel_mask]))