"""
Train an ilastik pixel classification project with the autocontext method and use it for batch prediction.

How to use:
* Start a new ilastik pixel classification project and add one or more datasets.
* Select some features.
* Add some labels.
* Save project and exit ilastik.
* Run this script (see the command line arguments and the example calls below) or call the autocontext function directly.
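
Example calls (the paths and file names are illustrative):
    python autocontext.py --ilastik ~/ilastik/run_ilastik.sh --train myproject.ilp -n 3
    python autocontext.py --ilastik ~/ilastik/run_ilastik.sh --batch_predict cache --cache predict_cache --files data.h5/raw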
"""

import argparse
import glob
import os
import random
import shutil
import subprocess
import sys

import colorama as col
import vigra

from core.ilp import ILP
from core.ilp import merge_datasets, reshape_tzyxc
from core.labels import scatter_labels
from core.ilp_constants import default_export_key


def autocontext(ilastik_cmd, project, runs, label_data_nr, weights=None, predict_file=False):
"""Trains and predicts the ilastik project using the autocontext method.
The parameter weights can be used to take different amounts of the labels in each loop run.
Example: runs = 3, weights = [3, 2, 1]
    The sum of the weights is 6, so in the first run, 1/2 (== 3/6) of the labels are used,
    then 1/3 (== 2/6), then 1/6.
If weights is None, the labels are equally distributed over the loop runs.
:param ilastik_cmd: path to run_ilastik.sh
:param project: the ILP object of the project
    :param runs: number of runs of the autocontext loop
    :param label_data_nr: index of the dataset that contains the labels (-1: use all datasets)
:param weights: weights for the labels
:param predict_file: if this is True, the --predict_file option of ilastik is used
"""
assert isinstance(project, ILP)
# Create weights if none were given.
if weights is None:
weights = [1]*runs
if len(weights) < runs:
raise Exception("The number of weights must not be smaller than the number of runs.")
weights = weights[:runs]
    # Copy the raw data to the output folder and reshape it to tzyxc.
project.extend_data_tzyxc()
# Get the number of datasets.
data_count = project.data_count
# Get the current number of channels in the datasets.
# The data in those channels is left unchanged when the ilastik output is merged back.
keep_channels = [project.get_channel_count(i) for i in range(data_count)]
    # Read the labels and split them into parts, so not all labels are used in each loop run.
label_count = len(project.label_names)
if label_data_nr == -1:
blocks_with_slicing = [(i, project.get_labels(i)) for i in xrange(project.labelsets_count)]
else:
blocks_with_slicing = [(label_data_nr, project.get_labels(label_data_nr))]
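    # scatter_labels splits each labelset into `runs` random subsets whose sizes follow the given
    # weights, so every autocontext round trains on a different portion of the labels.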
scattered_labels_list = [scatter_labels(blocks, label_count, runs, weights)
for i, (blocks, block_slices) in blocks_with_slicing]
# Do the autocontext loop.
for i in range(runs):
print col.Fore.GREEN + "- Running autocontext training round %d of %d -" % (i+1, runs) + col.Fore.RESET
# Insert the subset of the labels into the project.
for (k, (blocks, block_slices)), scattered_labels in zip(blocks_with_slicing, scattered_labels_list):
split_blocks = scattered_labels[i]
project.replace_labels(k, split_blocks, block_slices)
# Retrain the project.
print col.Fore.GREEN + "Retraining:" + col.Fore.RESET
project.retrain(ilastik_cmd)
# Save the project so it can be used in the batch prediction.
filename = "rf_" + str(i).zfill(len(str(runs-1))) + ".ilp"
filename = os.path.join(project.cache_folder, filename)
print col.Fore.GREEN + "Saving the project to " + filename + col.Fore.RESET
project.save(filename, remove_labels=True, remove_internal_data=True)
# Predict all datasets.
print col.Fore.GREEN + "Predicting all datasets:" + col.Fore.RESET
project.predict_all_datasets(ilastik_cmd, predict_file=predict_file)
# Merge the probabilities back into the datasets.
print col.Fore.GREEN + "Merging output back into datasets." + col.Fore.RESET
for k in range(data_count):
project.merge_output_into_dataset(k, keep_channels[k])
# Insert the original labels back into the project.
for k, (blocks, block_slices) in blocks_with_slicing:
project.replace_labels(k, blocks, block_slices)


def autocontext_forests(dirname):
    """Find the ilastik random forest project files of the given trained autocontext.
:param dirname: autocontext cache folder
:return: list with ilastik random forest filenames
"""
rf_files = []
for filename in os.listdir(dirname):
fullname = os.path.join(dirname, filename)
if os.path.isfile(fullname) and len(filename) >= 8:
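            # Split a filename such as "rf_0.ilp" into the prefix "rf_", the index and the extension.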
base, middle, end = filename[:3], filename[3:-4], filename[-4:]
            if base == "rf_" and end == ".ilp":
rf_files.append((int(middle), fullname))
rf_files = sorted(rf_files)
rf_indices, rf_files = zip(*rf_files)
    assert rf_indices == tuple(xrange(len(rf_files)))  # check that the indices are exactly 0, 1, 2, ...
return rf_files


def batch_predict(args, ilastik_args):
"""Do the batch prediction.
:param args: command line arguments
:param ilastik_args: additional ilastik arguments
"""
# Create the folder for the intermediate results.
if not os.path.isdir(args.cache):
os.makedirs(args.cache)
# Find the random forest files.
rf_files = autocontext_forests(args.batch_predict)
n = len(rf_files)
# Get the output format arguments.
default_output_format = "hdf5"
default_output_filename_format = os.path.join(args.cache, "{nickname}_probs.h5")
ilastik_parser = argparse.ArgumentParser()
ilastik_parser.add_argument("--output_format", type=str, default=default_output_format)
ilastik_parser.add_argument("--output_filename_format", type=str, default=default_output_filename_format)
ilastik_parser.add_argument("--output_internal_path", type=str, default=default_export_key())
format_args, ilastik_args = ilastik_parser.parse_known_args(ilastik_args)
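    # The intermediate rounds always use the default hdf5 output settings; only the final round
    # honors the user-supplied output format arguments.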
output_formats = [default_output_format] * (n-1) + [format_args.output_format]
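    # With --no_overwrite, each intermediate round writes its own numbered _probs file instead of
    # overwriting a single intermediate file.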
if args.no_overwrite:
        output_filename_formats = [default_output_filename_format[:-3] + "_%s" % str(i).zfill(2) + default_output_filename_format[-3:]
                                   for i in xrange(n-1)] + [format_args.output_filename_format]
else:
output_filename_formats = [default_output_filename_format] * (n-1) + [format_args.output_filename_format]
output_internal_paths = [default_export_key()] * (n-1) + [format_args.output_internal_path]
# Reshape the data to tzyxc and move it to the cache folder.
outfiles = []
keep_channels = None
for i in xrange(len(args.files)):
# Read the data and attach axistags.
filename = args.files[i]
if ".h5/" in filename or ".hdf5/" in filename:
data_key = os.path.basename(filename)
data_path = filename[:-len(data_key)-1]
data = vigra.readHDF5(data_path, data_key)
else:
data_key = default_export_key()
data_path_base, data_path_ext = os.path.splitext(filename)
data_path = data_path_base + ".h5"
data = vigra.readImage(filename)
if not hasattr(data, "axistags"):
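            # The data was loaded without axistags (e.g. from a plain image file), so guess the
            # tags from the number of dimensions.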
default_tags = {1: "x",
2: "xy",
3: "xyz",
4: "xyzc",
5: "txyzc"}
data = vigra.VigraArray(data, axistags=vigra.defaultAxistags(default_tags[len(data.shape)]),
dtype=data.dtype)
new_data = reshape_tzyxc(data)
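        # Remember the channel count of the raw data: when the probability maps are merged back
        # into the dataset, those original channels are left unchanged.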
if i == 0:
c_index = new_data.axistags.index("c")
keep_channels = new_data.shape[c_index]
# Save the reshaped dataset.
output_filename = os.path.split(data_path)[1]
output_filename = os.path.join(args.cache, output_filename)
vigra.writeHDF5(new_data, output_filename, data_key, compression=args.compression)
args.files[i] = output_filename + "/" + data_key
if args.no_overwrite:
outfiles.append([os.path.splitext(output_filename)[0] + "_probs_%s.h5" % str(i).zfill(2) for i in xrange(n-1)])
else:
outfiles.append([os.path.splitext(output_filename)[0] + "_probs.h5"] * (n-1))
assert keep_channels > 0
# Run the batch prediction.
for i in xrange(n):
rf_file = rf_files[i]
output_format = output_formats[i]
output_filename_format = output_filename_formats[i]
output_internal_path = output_internal_paths[i]
filename_key = os.path.basename(args.files[0])
filename_path = args.files[0][:-len(filename_key)-1]
# Quick hack to prevent the ilastik error "wrong number of channels".
p = ILP(rf_file, args.cache, compression=args.compression)
for j in xrange(p.data_count):
p.set_data_path_key(j, filename_path, filename_key)
# Call ilastik to run the batch prediction.
cmd = [args.ilastik,
"--headless",
"--project=%s" % rf_file,
"--output_format=%s" % output_format,
"--output_filename_format=%s" % output_filename_format,
"--output_internal_path=%s" % output_internal_path]
if args.predict_file:
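            # Write the absolute input paths to a text file and hand it to ilastik via
            # --predict_file instead of listing every file on the command line.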
pfile = os.path.join(args.cache, "predict_file.txt")
with open(pfile, "w") as f:
for pf in args.files:
f.write(os.path.abspath(pf) + "\n")
cmd.append("--predict_file=%s" % pfile)
else:
cmd += args.files
print col.Fore.GREEN + "- Running autocontext batch prediction round %d of %d -" % (i+1, n) + col.Fore.RESET
subprocess.call(cmd, stdout=sys.stdout)
if i < n-1:
# Merge the probabilities back to the original file.
for filename, filename_out in zip(args.files, outfiles):
filename_key = os.path.basename(filename)
filename_path = filename[:-len(filename_key)-1]
merge_datasets(filename_path, filename_key, filename_out[i], output_internal_path, n=keep_channels,
compression=args.compression)


def train(args):
"""Do the autocontext training.
:param args: command line arguments
"""
# Copy the project file.
# TODO: If the file exists, ask the user if it shall be deleted.
if os.path.isfile(args.outfile):
os.remove(args.outfile)
shutil.copyfile(args.train, args.outfile)
# Create an ILP object for the project.
proj = ILP(args.outfile, args.cache, args.compression)
# Do the autocontext loop.
autocontext(args.ilastik, proj, args.nloops, args.labeldataset, weights=args.weights, predict_file=args.predict_file)


def process_command_line():
"""Parse command line arguments.
"""
# Add the command line arguments.
parser = argparse.ArgumentParser(description="ilastik autocontext",
formatter_class=argparse.ArgumentDefaultsHelpFormatter)
# General arguments.
parser.add_argument("--ilastik", type=str, required=True,
help="path to the file run_ilastik.sh")
parser.add_argument("--predict_file", action="store_true",
help="add this flag if ilastik supports the --predict_file option")
parser.add_argument("-c", "--cache", type=str, default="cache",
help="name of the cache folder")
parser.add_argument("--compression", default="lzf", type=str, choices=["lzf", "gzip", "szip", "None"],
help="compression filter for the hdf5 files")
parser.add_argument("--clear_cache", action="store_true",
help="clear the cache folder without asking")
parser.add_argument("--keep_cache", action="store_true",
help="keep the cache folder without asking")
# Training arguments.
parser.add_argument("--train", type=str,
help="path to the ilastik project that will be used for training")
parser.add_argument("-o", "--outfile", type=str, default="",
help="output file")
parser.add_argument("-n", "--nloops", type=int, default=3,
help="number of autocontext loop iterations")
parser.add_argument("-d", "--labeldataset", type=int, default=-1,
help="id of dataset in the ilp file that contains the labels (-1: use all datasets)")
parser.add_argument("--seed", type=int, default=None,
help="the random seed")
parser.add_argument("--weights", type=float, nargs="*", default=[],
help="amount of labels that are used in each round")
# Batch prediction arguments.
parser.add_argument("--batch_predict", type=str,
help="path of the cache folder of a previously trained autocontext that will be used for batch "
"prediction")
parser.add_argument("--files", type=str, nargs="+",
help="the files for the batch prediction")
parser.add_argument("--no_overwrite", action="store_true",
help="create one _probs file for each autocontext iteration in the batch prediction")
# Do the parsing.
args, ilastik_args = parser.parse_known_args()
# Expand the filenames.
args.ilastik = os.path.expanduser(args.ilastik)
args.cache = os.path.expanduser(args.cache)
if args.train is not None:
args.train = os.path.expanduser(args.train)
args.outfile = os.path.expanduser(args.outfile)
if args.batch_predict is not None:
args.batch_predict = os.path.expanduser(args.batch_predict)
# Check if ilastik is an executable.
if not os.path.isfile(args.ilastik) or not os.access(args.ilastik, os.X_OK):
raise Exception("%s is not an executable file." % args.ilastik)
# Check that only one of the options --clear_cache, --keep_cache was set.
if args.clear_cache and args.keep_cache:
raise Exception("--clear_cache and --keep_cache must not be combined.")
# Check for conflicts between training and batch prediction arguments.
if args.train is None and args.batch_predict is None:
raise Exception("One of the arguments --train or --batch_predict must be given.")
if args.train is not None and args.batch_predict is not None:
raise Exception("--train and --batch_predict must not be combined.")
# Check if the training arguments are valid.
if args.train:
if len(ilastik_args) > 0:
raise Exception("The training does not accept unknown arguments: %s" % ilastik_args)
if args.files is not None:
raise Exception("--train cannot be used for batch prediction.")
if not os.path.isfile(args.train):
raise Exception("%s is not a file." % args.train)
if len(args.outfile) == 0:
file_path, file_ext = os.path.splitext(args.train)
args.outfile = file_path + "_out" + file_ext
if args.labeldataset < -1:
raise Exception("Wrong id of label dataset: %d" % args.d)
if args.compression == "None":
args.compression = None
if len(args.weights) == 0:
args.weights = None
if args.weights is not None and len(args.weights) != args.nloops:
raise Exception("Number of weights must be equal to number of autocontext iterations.")
# Check if the batch prediction arguments are valid.
if args.batch_predict:
if os.path.normpath(os.path.abspath(args.batch_predict)) == os.path.normpath(os.path.abspath(args.cache)):
raise Exception("The --batch_predict and --cache directories must be different.")
if args.files is None:
raise Exception("Tried to use batch prediction without --files.")
if not os.path.isdir(args.batch_predict):
raise Exception("%s is not a directory." % args.batch_predict)
# Expand filenames that include *.
expanded_files = [os.path.expanduser(f) for f in args.files]
args.files = []
for filename in expanded_files:
if "*" in filename:
if ".h5/" in filename or ".hdf5/" in filename:
if ".h5/" in filename:
i = filename.index(".h5")
filename_path = filename[:i+3]
filename_key = filename[i+4:]
else:
i = filename.index(".hdf5")
filename_path = filename[:i+5]
filename_key = filename[i+6:]
to_append = glob.glob(filename_path)
to_append = [f + "/" + filename_key for f in to_append]
args.files += to_append
else:
args.files += glob.glob(filename)
else:
args.files.append(filename)
# Remove the --headless, --project and --output_internal_path arguments.
ilastik_parser = argparse.ArgumentParser()
ilastik_parser.add_argument("--headless", action="store_true")
ilastik_parser.add_argument("--project", type=str)
ilastik_args = ilastik_parser.parse_known_args(ilastik_args)[1]
return args, ilastik_args


def main():
    """Run the autocontext training or the batch prediction, depending on the command line arguments."""
# Read command line arguments.
args, ilastik_args = process_command_line()
# Initialize colorama and random seeds.
random.seed(args.seed)
col.init()
# Clear the cache folder.
if os.path.isdir(args.cache):
print "The cache folder", os.path.abspath(args.cache), "already exists."
clear_cache = False
if args.clear_cache:
print "The option --clear_cache was set, so the cache folder will be cleared."
clear_cache = True
elif args.keep_cache:
print "The option --keep_cache was set, so the cache folder will not be cleared."
else:
cc_input = raw_input("Clear cache folder? [y|n] : ")
if cc_input in ["y", "Y"]:
clear_cache = True
if clear_cache:
for f in os.listdir(args.cache):
f_path = os.path.join(args.cache, f)
try:
if os.path.isfile(f_path):
os.remove(f_path)
elif os.path.isdir(f_path):
shutil.rmtree(f_path)
except Exception, e:
print e
print "Cleared cache folder."
else:
print "Cache folder not cleared."
if args.train:
# Do the autocontext training.
train(args)
else:
# Do the batch prediction.
assert args.batch_predict
batch_predict(args, ilastik_args)
return 0


if __name__ == "__main__":
status = main()
sys.exit(status)