train_predict.py (123 lines · 100 loc · 4.93 KB)
#!/usr/bin/env python3
import numpy as np
np.random.seed(1)
import random
random.seed(1)
import os
import tensorflow as tf
tf.set_random_seed(1)
from keras.callbacks import ReduceLROnPlateau, ModelCheckpoint, EarlyStopping, TensorBoard
from data import build_train_val, trainvalGenerator, testGenerator, save_result
from model import unet, unet_dilated
from losses import dice_loss
from mask_to_submission import make_submission
NUM_EPOCH = 100
NUM_TRAINING_STEP = 1000
NUM_VALIDATION_STEP = 80
TEST_SIZE = 50
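# Note (added for clarity): with batch_size=2 on the training generator below,
# each epoch draws NUM_TRAINING_STEP * 2 = 2000 augmented crops; TEST_SIZE is
# the number of test images pulled from the test generator at prediction time.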
# paths
train_path = os.path.join("data", "training")
val_path = os.path.join("data", "validation")
test_path = os.path.join("data", "test_set_images")
predict_path = "predict_images"
submission_path = "submission"
weight_path = "weights"
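# Expected input layout (inferred from the generator arguments below):
#   data/training/images and data/training/groundtruth, with the same structure
#   under data/validation once it has been built.
# Defensive addition: the output directories may not exist on a fresh clone, so
# create them up front for ModelCheckpoint, save_result and make_submission
# (harmless if they are already present).
for directory in (predict_path, submission_path, weight_path):
    os.makedirs(directory, exist_ok=True)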
if not os.path.exists(val_path):
    print("Build training and validation data set...")
    build_train_val(train_path, val_path, val_size=0.2)
else:
    print("Found existing training and validation data set...")
print("Create generator for training and validation...")
# Arguments for data augmentation
data_gen_args = dict(rotation_range=45,
                     width_shift_range=0.1,
                     height_shift_range=0.1,
                     horizontal_flip=True,
                     vertical_flip=True,
                     fill_mode='reflect')
# Build generator for training and validation set
trainGen, valGen = trainvalGenerator(batch_size=2, aug_dict=data_gen_args,
                                     train_path=train_path, val_path=val_path,
                                     image_folder='images', mask_folder='groundtruth',
                                     train_dir=None,  # Set it to None if you don't want to save
                                     val_dir=None,    # Set it to None if you don't want to save
                                     target_size=(400, 400), seed=1)
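# Both generators yield matched (image, mask) batches of two 400x400 samples,
# with masks read from the 'groundtruth' folders (per the arguments above).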
print("Build model and training...")
print("...Build & train the modified U-Net with 32 filters...")
# Build model
model_32 = unet(n_filter=32, activation='elu', dropout_rate=0.2, loss=dice_loss)
# Callback functions
callbacks = [
    # EarlyStopping(monitor='val_loss', patience=9, verbose=1, min_delta=1e-4),
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1, min_delta=1e-4),
    ModelCheckpoint(os.path.join(weight_path, 'weights_32.h5'), monitor='val_loss', save_best_only=True, verbose=1)
]
# Training
history_32 = model_32.fit_generator(generator=trainGen, steps_per_epoch=NUM_TRAINING_STEP,
                                    validation_data=valGen, validation_steps=NUM_VALIDATION_STEP,
                                    epochs=NUM_EPOCH, callbacks=callbacks)
print("...Build & train the modified U-Net with 64 filters...")
# Build model
model_64 = unet(n_filter=64, activation='elu', dropout_rate=0.2, loss=dice_loss)
# Callback functions
callbacks = [
    # EarlyStopping(monitor='val_loss', patience=9, verbose=1, min_delta=1e-4),
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1, min_delta=1e-4),
    ModelCheckpoint(os.path.join(weight_path, 'weights_64.h5'), monitor='val_loss', save_best_only=True, verbose=1)
]
# Training
history_64 = model_64.fit_generator(generator=trainGen, steps_per_epoch=NUM_TRAINING_STEP,
                                    validation_data=valGen, validation_steps=NUM_VALIDATION_STEP,
                                    epochs=NUM_EPOCH, callbacks=callbacks)
print("...Build & train the U-Net with dilated convolution...")
# Build model
model_dilated = unet_dilated(n_filter=32, activation='elu', loss=dice_loss, dropout=False, batchnorm=False)
# Callback functions
callbacks = [
    # EarlyStopping(monitor='val_loss', patience=9, verbose=1, min_delta=1e-4),
    ReduceLROnPlateau(monitor='val_loss', factor=0.1, patience=3, verbose=1, min_delta=1e-4),
    ModelCheckpoint(os.path.join(weight_path, 'weights_dilated.h5'), monitor='val_loss', save_best_only=True, verbose=1)
]
# Training
history_dilated = model_dilated.fit_generator(generator=trainGen, steps_per_epoch=NUM_TRAINING_STEP,
                                              validation_data=valGen, validation_steps=NUM_VALIDATION_STEP,
                                              epochs=NUM_EPOCH, callbacks=callbacks)
print("Predict and save results...")
print("...For U-Net with 32 filters...")
testGene = testGenerator(test_path)
result_1 = model_32.predict_generator(testGene, TEST_SIZE, verbose=1)
print("...For U-Net with 64 filters...")
testGene = testGenerator(test_path)
result_2 = model_64.predict_generator(testGene, TEST_SIZE, verbose=1)
print("...For U-Net with dilated convolution...")
testGene = testGenerator(test_path)
result_3 = model_dilated.predict_generator(testGene, TEST_SIZE, verbose=1)
print("...Averaging the prediction results...")
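# Simple ensemble: average the three models' per-pixel predictions with equal
# weights before handing the soft maps to save_result / make_submission.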
result = (result_1 + result_2 + result_3)/3
save_result(predict_path, result)
print("Make submission...")
make_submission(predict_path, test_size=TEST_SIZE, submission_filename=os.path.join(submission_path, "submission.csv"))
print("Done!")
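
The script imports dice_loss from losses.py, which is not shown on this page. For reference, a minimal soft-Dice loss for binary segmentation with the Keras backend might look like the sketch below; it is an illustrative assumption, not the repository's implementation (the name soft_dice_loss and the smooth term are placeholders).

from keras import backend as K

def soft_dice_loss(y_true, y_pred, smooth=1.0):
    # Flatten both tensors, compute the soft Dice coefficient, and return
    # 1 - Dice so that minimizing the loss maximizes mask overlap.
    y_true_f = K.flatten(y_true)
    y_pred_f = K.flatten(y_pred)
    intersection = K.sum(y_true_f * y_pred_f)
    dice = (2.0 * intersection + smooth) / (K.sum(y_true_f) + K.sum(y_pred_f) + smooth)
    return 1.0 - dice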