-
Notifications
You must be signed in to change notification settings - Fork 10
/
compute_features.py
224 lines (198 loc) · 7.15 KB
/
compute_features.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
import argparse
from multiprocessing import Value
import numpy as np
import pandas as pd
from joblib import Parallel, delayed
import mne
from mne_bids import BIDSPath
import coffeine
from mne_features.feature_extraction import extract_features
from mne.minimum_norm import apply_inverse_cov
import h5io
from utils import prepare_dataset
# Recognized dataset names and feature-extraction pipelines.
DATASETS = ['chbp', 'lemon', 'tuab', 'camcan']
FEATURE_TYPE = ['fb_covs', 'handcrafted', 'source_power']

parser = argparse.ArgumentParser(description='Compute features.')
parser.add_argument(
    '-d', '--dataset',
    default=None,
    nargs='+',
    help='the dataset for which features should be computed')
parser.add_argument(
    '-t', '--feature_type',
    default=None,
    nargs='+', help='Type of features to compute')
parser.add_argument(
    '--n_jobs', type=int, default=1,
    help='number of parallel processes to use (default: 1)')
args = parser.parse_args()
datasets = args.dataset
feature_types = args.feature_type
n_jobs = args.n_jobs

# No explicit selection means "run everything".
if datasets is None:
    datasets = list(DATASETS)
if feature_types is None:
    feature_types = list(FEATURE_TYPE)

# Cartesian product of (dataset, feature_type) combinations to process.
tasks = [(ds, bs) for ds in datasets for bs in feature_types]

# Validate the selection up front, before any heavy computation starts.
for dataset, feature_type in tasks:
    if dataset not in DATASETS:
        raise ValueError(f"The dataset '{dataset}' passed is unknown")
    if feature_type not in FEATURE_TYPE:
        raise ValueError(f"The benchmark '{feature_type}' passed is unknown")
print(f"Running benchmarks: {', '.join(feature_types)}")
print(f"Datasets: {', '.join(datasets)}")

# Set to True for a quick, shrunken end-to-end run (see main loop below).
DEBUG = False

# Frequency bands (Hz) used for the filter-bank covariance features.
frequency_bands = {
    "low": (0.1, 1),
    "delta": (1, 4),
    "theta": (4.0, 8.0),
    "alpha": (8.0, 15.0),
    "beta_low": (15.0, 26.0),
    "beta_mid": (26.0, 35.0),
    "beta_high": (35.0, 49)
}

# mne-features univariate/bivariate feature functions for the
# 'handcrafted' benchmark.
hc_selected_funcs = [
    'std',
    'kurtosis',
    'skewness',
    'quantile',
    'ptp_amp',
    'mean',
    'pow_freq_bands',
    'spect_entropy',
    'app_entropy',
    'samp_entropy',
    'svd_entropy',
    'hurst_exp',
    'hjorth_complexity',
    'hjorth_mobility',
    'line_length',
    'wavelet_coef_energy',
    'higuchi_fd',
    'zero_crossings',
    'svd_fisher_info'
]

# Extra parameters for the feature functions above
# (keys follow the mne-features 'func__param' convention).
hc_func_params = {
    'quantile__q': [0.1, 0.25, 0.75, 0.9],
    'pow_freq_bands__freq_bands': [0, 2, 4, 8, 13, 18, 24, 30, 49],
    'pow_freq_bands__ratios': 'all',
    'pow_freq_bands__ratios_triu': True,
    'pow_freq_bands__log': True,
    'pow_freq_bands__normalize': None,
}
def extract_fb_covs(epochs, condition):
    """Compute filter-bank covariance features for one condition.

    Returns the coffeine feature dict, with the computation meta-info
    stored under the extra key 'meta_info'.
    """
    out, meta_info = coffeine.compute_features(
        epochs[condition],
        features=('covs',),
        n_fft=1024,
        n_overlap=512,
        fs=epochs.info['sfreq'],
        fmax=49,
        frequency_bands=frequency_bands)
    out['meta_info'] = meta_info
    return out
def extract_handcrafted_feats(epochs, condition):
    """Extract the hand-crafted (mne-features) feature set for one condition.

    Returns a dict with the raw feature array under the key 'feats'.
    """
    data = epochs[condition].get_data()
    feats = extract_features(
        data,
        epochs.info['sfreq'],
        hc_selected_funcs,
        funcs_params=hc_func_params,
        n_jobs=1,
        ch_names=epochs.ch_names,
        return_as_df=False)
    return {'feats': feats}
def extract_source_power(bp, info, subject, subjects_dir, covs):
    """Project per-band sensor covariances to source space.

    For each frequency band in ``covs`` (stacked along the first axis),
    applies the precomputed inverse operator (dSPM) and averages source
    power within 'aparc_sub' labels.  Returns a list with one diagonal
    (n_labels, n_labels) array per band.

    NOTE(review): reads the module-level ``dataset`` global to pick the
    template subject — 'fsaverage_small' for camcan, 'fsaverage' otherwise.
    """
    fname_inv = bp.copy().update(
        suffix='inv', processing=None, extension='.fif')
    inverse_operator = mne.minimum_norm.read_inverse_operator(fname_inv)

    # Cortical parcellation; drop the catch-all 'unknown' labels.
    template = 'fsaverage_small' if dataset == 'camcan' else 'fsaverage'
    labels = mne.read_labels_from_annot(
        template, 'aparc_sub', subjects_dir=subjects_dir)
    labels = [label for label in labels if 'unknown' not in label.name]

    out = []
    for band_cov in covs:  # one covariance matrix per frequency band
        cov = mne.Covariance(
            data=band_cov,
            names=info['ch_names'],
            bads=info['bads'],
            projs=info['projs'],
            nfree=0)  # nfree unused downstream — TODO confirm
        stc = apply_inverse_cov(
            cov, info, inverse_operator, nave=1, method="dSPM")
        label_power = mne.extract_label_time_course(
            stc, labels, inverse_operator['src'], mode="mean")
        out.append(np.diag(label_power[:, 0]))
    return out
def run_subject(subject, cfg, condition):
    """Compute ``feature_type`` features for one subject.

    Returns the feature payload on success, or a short error string
    ('no file', 'condition not found', or an exception repr) so that one
    failing subject does not abort the whole parallel batch — the caller
    records these strings in the log CSV.
    """
    task = cfg.task
    deriv_root = cfg.deriv_root
    data_type = cfg.data_type
    session = cfg.session
    # Strip a literal BIDS 'ses-' prefix.  NOTE: str.lstrip('ses-') would
    # strip any leading run of the characters {s, e, -}, corrupting names
    # such as 'ses-s01' -> '01'; slice off the prefix instead.
    if session.startswith('ses-'):
        session = session[len('ses-'):]

    bp_args = dict(root=deriv_root, subject=subject,
                   datatype=data_type, processing="autoreject",
                   task=task,
                   check=False, suffix="epo")
    if session:
        bp_args['session'] = session
    bp = BIDSPath(**bp_args)

    if not bp.fpath.exists():
        return 'no file'

    epochs = mne.read_epochs(bp, proj=False, preload=True)
    if not any(condition in cc for cc in epochs.event_id):
        return 'condition not found'
    out = None

    # make sure that no EOG/ECG made it into the selection
    epochs.pick_types(**{data_type: True})

    try:
        if feature_type == 'fb_covs':
            out = extract_fb_covs(epochs, condition)
        elif feature_type == 'handcrafted':
            out = extract_handcrafted_feats(epochs, condition)
        elif feature_type == 'source_power':
            covs = extract_fb_covs(epochs, condition)
            covs = covs['covs']
            out = extract_source_power(
                bp, epochs.info, subject, cfg.subjects_dir, covs)
        else:
            # Was a bare `NotImplementedError()` expression (never raised);
            # raise it so unexpected feature types are reported in the log.
            raise NotImplementedError(
                f"Unknown feature type: {feature_type}")
    except Exception as err:
        # Was `raise err` followed by an unreachable `return repr(err)`;
        # returning the repr matches the caller's string-based error
        # logging and keeps the batch running.
        return repr(err)

    return out
# Main driver: compute and save features for every (dataset, feature_type)
# combination selected on the command line.
for dataset, feature_type in tasks:
    cfg, subjects = prepare_dataset(dataset)
    # An explicit --n_jobs overrides the dataset's configured default.
    N_JOBS = cfg.N_JOBS if not n_jobs else n_jobs
    if DEBUG:
        # Shrink the workload (one subject, serial, one band, one feature)
        # for quick end-to-end runs.  This rebinds the module-level config
        # globals read by the extraction helpers above.
        subjects = subjects[:1]
        N_JOBS = 1
        frequency_bands = {"alpha": (8.0, 15.0)}
        hc_selected_funcs = ['std']
        hc_func_params = dict()
    for condition in cfg.feature_conditions:
        print(
            f"Computing {feature_type} features on {dataset} for '{condition}'")
        # Subject labels look like 'sub-XXX'; run_subject takes the bare id.
        features = Parallel(n_jobs=N_JOBS)(
            delayed(run_subject)(sub.split('-')[1], cfg=cfg,
                                 condition=condition) for sub in subjects)
        # run_subject returns a str on failure; keep only successful results.
        out = {sub: ff for sub, ff in zip(subjects, features)
               if not isinstance(ff, str)}
        # Output label encodes the recording condition per dataset family.
        label = None
        if dataset in ("chbp", "lemon"):
            label = 'pooled'
            # e.g. condition 'eyes/open' -> label 'eyes-open'
            if '/' in condition:
                label = f'eyes-{condition.split("/")[1]}'
        elif dataset in ("tuab", 'camcan'):
            label = 'rest'
        out_fname = cfg.deriv_root / f'features_{feature_type}_{label}.h5'
        log_out_fname = (
            cfg.deriv_root / f'feature_{feature_type}_{label}-log.csv')
        h5io.write_hdf5(
            out_fname,
            out,
            overwrite=True
        )
        print(f'Features saved under {out_fname}.')
        # Per-subject status column: 'OK' or the error string from run_subject.
        logging = ['OK' if not isinstance(ff, str) else ff for sub, ff in
                   zip(subjects, features)]
        out_log = pd.DataFrame({"ok": logging, "subject": subjects})
        out_log.to_csv(log_out_fname)
        print(f'Log saved under {log_out_fname}.')