Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ENH] make dss_line() faster #57

Merged
merged 5 commits into from
Dec 2, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pythonpackage.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
strategy:
max-parallel: 4
matrix:
python-version: [3.6, 3.7, 3.8, 3.9]
python-version: [3.7, 3.8, 3.9]

steps:
- uses: actions/checkout@v1
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
[![DOI](https://zenodo.org/badge/117451752.svg)](https://zenodo.org/badge/latestdoi/117451752)
[![twitter](https://img.shields.io/twitter/follow/lebababa?label=Twitter&style=flat&logo=Twitter)](https://twitter.com/intent/follow?screen_name=lebababa)

Denoising tools for M/EEG processing in Python 3.6+.
Denoising tools for M/EEG processing in Python 3.7+.

> **Disclaimer:** The project mostly consists of development code, although some modules and functions are already working. Bugs and performance problems are to be expected, so use at your own risk. More tests and improvements will be added in the future. Comments and suggestions are welcome.

Expand Down
264 changes: 197 additions & 67 deletions examples/example_detrend.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion examples/example_detrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
# -----------------------------------------------------------------------------
# We first try to fit a simple random walk process.
x = np.cumsum(np.random.randn(1000, 1), axis=0)
r = np.arange(1000)[:, None]
r = np.arange(1000.)[:, None]
r = np.hstack([r, r ** 2, r ** 3])
b, y = regress(x, r)

Expand Down
300 changes: 300 additions & 0 deletions examples/example_dss_line.ipynb

Large diffs are not rendered by default.

10 changes: 6 additions & 4 deletions meegkit/detrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def regress(x, r, w=None, threshold=1e-7, return_mean=False):
xx = demean(x[:, i], wc) * wc

# remove channel-specific-weighted mean from regressor
r = demean(r, wc)
r = demean(r, wc, inplace=True)
rr = r * wc
V, _ = pca(rr.T @ rr, thresh=threshold)
rr = rr.dot(V)
Expand Down Expand Up @@ -282,12 +282,14 @@ def _plot_detrend(x, y, w):
f = plt.figure()
gs = GridSpec(4, 1, figure=f)
ax1 = f.add_subplot(gs[:3, 0])
plt.plot(x, label='original', color='C0')
plt.plot(y, label='detrended', color='C1')
lines = ax1.plot(x, label='original', color='C0')
plt.setp(lines[1:], label="_")
lines = ax1.plot(y, label='detrended', color='C1')
plt.setp(lines[1:], label="_")
ax1.set_xlim(0, n_times)
ax1.set_xticklabels('')
ax1.set_title('Robust detrending')
ax1.legend()
ax1.legend(fontsize='smaller')

ax2 = f.add_subplot(gs[3, 0])
ax2.pcolormesh(w.T, cmap='Greys')
Expand Down
76 changes: 47 additions & 29 deletions meegkit/dss.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
"""Denoising source separation."""
# Authors: Nicolas Barascud <[email protected]>
# Maciej Szul <[email protected]>

import numpy as np
from scipy import linalg
from scipy.signal import welch

from .tspca import tsr
from .utils import (demean, gaussfilt, mean_over_trials, pca, smooth,
from .utils import (demean, gaussfilt, matmul3d, mean_over_trials, pca, smooth,
theshapeof, tscov, wpwr)

from numpy.lib.stride_tricks import sliding_window_view


def dss1(X, weights=None, keep1=None, keep2=1e-12):
"""DSS to maximise repeatability across trials.
Expand Down Expand Up @@ -134,7 +135,8 @@ def dss0(c0, c1, keep1=None, keep2=1e-9):
return todss, fromdss, pwr0, pwr1


def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False):
def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, blocksize=None,
show=False):
"""Apply DSS to remove power line artifacts.

Implements the ZapLine algorithm described in [1]_.
Expand All @@ -153,6 +155,11 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False):
FFT size (default=1024).
nkeep : int
Number of components to keep in DSS (default=None).
blocksize : int
If not None (default), covariance is computed on blocks of
``blocksize`` samples. This may improve performance for large datasets.
show: bool
If True, show DSS results (default=False).

Returns
-------
Expand Down Expand Up @@ -183,30 +190,45 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False):

"""
if X.shape[0] < nfft:
print('reducing nfft to {}'.format(X.shape[0]))
print('Reducing nfft to {}'.format(X.shape[0]))
nfft = X.shape[0]
n_samples, n_chans, n_trials = theshapeof(X)
X = demean(X)
n_samples, n_chans, _ = theshapeof(X)
if blocksize is None:
blocksize = n_samples

# cancels line_frequency and harmonics, light lowpass
xx = smooth(X, sfreq / fline)
# Recentre data
X = demean(X, inplace=True)

# residual (X=xx+xxx), contains line and some high frequency power
xxx = X - xx
# Cancel line_frequency and harmonics + light lowpass
X_filt = smooth(X, sfreq / fline)

# reduce dimensionality to avoid overfitting
# X - X_filt results in the artifact plus some residual biological signal
X_noise = X - X_filt

# Reduce dimensionality to avoid overfitting
if nkeep is not None:
xxx_cov = tscov(xxx)[0]
V, _ = pca(xxx_cov, nkeep)
xxxx = xxx @ V
cov_X_res = tscov(X_noise)[0]
V, _ = pca(cov_X_res, nkeep)
X_noise_pca = X_noise @ V
else:
xxxx = xxx.copy()
X_noise_pca = X_noise.copy()
nkeep = n_chans

# DSS to isolate line components from residual:
# Compute blockwise covariances of raw and biased data
n_harm = np.floor((sfreq / 2) / fline).astype(int)
c0, _ = tscov(xxxx)
c1, _ = tscov(gaussfilt(xxxx, sfreq, fline, 1, n_harm=n_harm))

c0 = np.zeros((nkeep, nkeep))
c1 = np.zeros((nkeep, nkeep))
for X_block in sliding_window_view(X_noise_pca, (blocksize, nkeep),
axis=(0, 1))[::blocksize, 0]:
# if n_trials>1, reshape to (n_samples, nkeep, n_trials)
if X_block.ndim == 3:
X_block = X_block.transpose(1, 2, 0)

# bias data
c0 += tscov(X_block)[0]
c1 += tscov(gaussfilt(X_block, sfreq, fline, fwhm=1, n_harm=n_harm))[0]

# DSS to isolate line components from residual
todss, _, pwr0, pwr1 = dss0(c0, c1)

if show:
Expand All @@ -217,23 +239,19 @@ def dss_line(X, fline, sfreq, nremove=1, nfft=1024, nkeep=None, show=False):
plt.title('DSS to enhance line frequencies')
plt.show()

# Remove line components from X_noise
idx_remove = np.arange(nremove)
if X.ndim == 3:
for t in range(n_trials): # line-dominated components
xxxx[..., t] = xxxx[..., t] @ todss[:, idx_remove]
elif X.ndim == 2:
xxxx = xxxx @ todss[:, idx_remove]

xxx, _, _, _ = tsr(xxx, xxxx) # project them out
X_artifact = matmul3d(X_noise_pca, todss[:, idx_remove])
X_res = tsr(X_noise, X_artifact)[0] # project them out

# reconstruct clean signal
y = xx + xxx
artifact = X - y
y = X_filt + X_res

# Power of components
p = wpwr(X - y)[0] / wpwr(X)[0]
print('Power of components removed by DSS: {:.2f}'.format(p))
return y, artifact
# return the reconstructed clean signal, and the artifact
return y, X - y


def dss_line_iter(data, fline, sfreq, win_sz=10, spot_sz=2.5,
Expand Down
10 changes: 4 additions & 6 deletions meegkit/tspca.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,8 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12):
wR = weights

# remove weighted means
X, mean1 = demean(X, wX, return_mean=True)
R = demean(R, wR)
X, mean1 = demean(X, wX, return_mean=True, inplace=True)
R = demean(R, wR, inplace=True)

# equalize power of R channels, the equalize power of the R PCs
# if R.shape[1] > 1:
Expand All @@ -182,11 +182,9 @@ def tsr(X, R, shifts=None, wX=None, wR=None, keep=None, thresh=1e-12):
y = np.zeros((n_samples_X, n_chans_X, n_trials_X))
for t in np.arange(n_trials_X):
r = multishift(R[..., t], shifts, reshape=True)
z = r @ regression
y[..., t] = X[:z.shape[0], :, t] - z

y, mean2 = demean(y, wX, return_mean=True)
y[..., t] = X[:z.shape[0], :, t] - (r @ regression)

y, mean2 = demean(y, wX, return_mean=True, inplace=True)
idx = np.arange(offset1, initial_samples - offset2)
mean_total = mean1 + mean2
weights = wR
Expand Down
4 changes: 2 additions & 2 deletions meegkit/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from .denoise import (demean, find_outlier_samples, find_outlier_trials,
mean_over_trials, wpwr)
from .matrix import (fold, matmul3d, multishift, multismooth, normcol,
relshift, shift, shiftnd, theshapeof, unfold, unsqueeze,
widen_mask)
relshift, shift, shiftnd, sliding_window, theshapeof,
unfold, unsqueeze, widen_mask)
from .sig import (gaussfilt, hilbert_envelope, slope_sum, smooth,
spectral_envelope, teager_kaiser)
from .stats import (bootstrap_ci, bootstrap_snr, cronbach, rms, robust_mean,
Expand Down
23 changes: 16 additions & 7 deletions meegkit/utils/denoise.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,32 @@
from .matrix import fold, theshapeof, unfold, _check_weights


def demean(X, weights=None, return_mean=False):
def demean(X, weights=None, return_mean=False, inplace=False):
"""Remove weighted mean over rows (samples).

Parameters
----------
X : array, shape=(n_samples, n_channels[, n_trials])
Data.
weights : array, shape=(n_samples)
return_mean : bool
If True, also return signal mean (default=False).
inplace : bool
If True, save the resulting array in X (default=False).

Returns
-------
demeaned_X : array, shape=(n_samples, n_channels[, n_trials])
X : array, shape=(n_samples, n_channels[, n_trials])
Centered data.
mn : array
Mean value.

"""
weights = _check_weights(weights, X)
ndims = X.ndim

if not inplace:
X = X.copy()
n_samples, n_chans, n_trials = theshapeof(X)
X = unfold(X)

Expand All @@ -43,18 +50,20 @@ def demean(X, weights=None, return_mean=False):
raise ValueError('Weight array should have either the same ' +
'number of columns as X array, or 1 column.')

demeaned_X = X - mn
else:
mn = np.mean(X, axis=0, keepdims=True)
demeaned_X = X - mn

# Remove mean (no copy)
X -= mn
# np.subtract(X, mn, out=X)

if n_trials > 1 or ndims == 3:
demeaned_X = fold(demeaned_X, n_samples)
X = fold(X, n_samples)

if return_mean:
return demeaned_X, mn # the_mean.shape=(1, the_mean.shape[0])
return X, mn # the_mean.shape=(1, the_mean.shape[0])
else:
return demeaned_X
return X


def mean_over_trials(X, weights=None):
Expand Down
26 changes: 16 additions & 10 deletions meegkit/utils/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from numpy.lib.stride_tricks import as_strided


def sliding_window(data, window, step=1, padded=False, axis=-1, copy=True):
def sliding_window(data, window, step=1, axis=-1, copy=True):
"""Calculate a sliding window over a signal.

Parameters
Expand All @@ -30,8 +30,8 @@ def sliding_window(data, window, step=1, padded=False, axis=-1, copy=True):

Notes
-----
- Be wary of setting `copy` to `False` as undesired sideffects with the
output values may occur.
Be wary of setting `copy` to `False` as undesired side effects with the
output values may occur.

Examples
--------
Expand Down Expand Up @@ -495,17 +495,18 @@ def unsqueeze(X):


def fold(X, epoch_size):
"""Fold 2D X into 3D."""
"""Fold 2D (n_times, n_channels) X into 3D (n_times, n_chans, n_trials)."""
if X.ndim == 1:
X = X[:, np.newaxis]
if X.ndim > 2:
raise AttributeError('X must be 2D at most')

n_chans = X.shape[0] // epoch_size
nt = X.shape[0] // epoch_size
nc = X.shape[1]
if X.shape[0] / epoch_size >= 1:
X = np.transpose(np.reshape(X, (epoch_size, n_chans, X.shape[1]),
order="F").copy(), [0, 2, 1])
return X
return X.reshape((epoch_size, nt, nc), order="F").transpose([0, 2, 1])
else:
return X


def unfold(X):
Expand Down Expand Up @@ -622,9 +623,14 @@ def matmul3d(X, mixin):
Projection.

"""
assert X.ndim == 3, 'data must be of shape (n_samples, n_chans, n_trials)'
assert mixin.ndim == 2, 'mixing matrix must be 2D'
return np.einsum('sct,ck->skt', X, mixin)

if X.ndim == 2:
return X @ mixin
elif X.ndim == 3:
return np.einsum('sct,ck->skt', X, mixin)
else:
raise RuntimeError('X must be (n_samples, n_chans, n_trials)')


def _check_shifts(shifts, allow_floats=False):
Expand Down
10 changes: 5 additions & 5 deletions meegkit/utils/sig.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,13 +345,13 @@ def gaussfilt(data, srate, f, fwhm, n_harm=1, shift=0, return_empvals=False,
fx = fx + gauss

# filter
tmp = np.fft.fft(data, axis=0)
if data.ndim == 2:
tmp *= fx[:, None]
filtdat = 2 * np.real(np.fft.ifft(
np.fft.fft(data, axis=0) * fx[:, None], axis=0))
elif data.ndim == 3:
tmp *= fx[:, None, None]

filtdat = 2 * np.real(np.fft.ifft(tmp, axis=0))
filtdat = 2 * np.real(np.fft.ifft(
np.fft.fft(data, axis=0) * fx[:, None, None],
axis=0))

if return_empvals or show:
empVals[0] = hz[idx_p]
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
numpy
numpy>=1.20
scipy
matplotlib
scikit-learn
Expand Down
Binary file added tests/data/eeg_for_trca.mat
Binary file not shown.
2 changes: 1 addition & 1 deletion tests/test_detrend.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def test_regress():
# Simple regression example, no weights
# fit random walk
y = np.cumsum(np.random.randn(1000, 1), axis=0)
x = np.arange(1000)[:, None]
x = np.arange(1000.)[:, None]
x = np.hstack([x, x ** 2, x ** 3])
[b, z] = regress(y, x)

Expand Down
Loading