-
Notifications
You must be signed in to change notification settings - Fork 0
/
deep_unet_test_elli.py
73 lines (57 loc) · 2.49 KB
/
deep_unet_test_elli.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import model.deep_unet_decoder as dec
from model.encoder import sepresnet152v2_c
from utility.loss import seg_loss, dice_coef, iou_coef
import os
import glob
import numpy as np
import tensorflow as tf
from tensorflow.keras.optimizers import Adam, SGD
from tensorflow.keras.callbacks import ModelCheckpoint
# Expose only CUDA device index 2 to this process.
# NOTE(review): tensorflow is already imported above; this only takes effect if
# TF has not yet initialised its GPU devices. Moving it above the tensorflow
# import would be safer — confirm.
os.environ.update({"CUDA_VISIBLE_DEVICES": "2"})
def npy_loader(maindir, seed: int = 2):
    """Yield ``(image, mask)`` pairs from every ``.npy`` file in *maindir*.

    Each file is loaded with ``np.load``. Channel 0 of the last axis is
    treated as the image and channel 1 as the mask. Both are returned with a
    trailing singleton channel axis added.

    Args:
        maindir: Directory path *prefix*. The glob pattern is built as
            ``maindir + '*.npy'``, so a directory must end with a path
            separator (e.g. ``'/data/train/'``).
        seed: Currently unused. It is accepted so callers can pass it through.

    Yields:
        Tuples ``(img[..., None], mask[..., None])``, in sorted file-path order.
    """
    # glob() returns matches in arbitrary, filesystem-dependent order; sort so
    # the sample order (and therefore batch composition) is reproducible.
    pathlist = sorted(glob.glob(maindir + '*.npy'))
    for path in pathlist:
        array = np.load(path)
        img = array[:, :, 0]
        # Affine rescale x -> 2x - 1. If channel 0 is stored in [0, 1]
        # (assumed — TODO confirm), this maps it to [-1, 1].
        img = (img * 2) - 1
        mask = array[:, :, 1]
        yield img[..., np.newaxis], mask[..., np.newaxis]
def npy_dataset(maindir, shape_i, shape_m, seed: int = 2, batch: int = 1):
    """Build a batched ``tf.data.Dataset`` of ``(image, mask)`` pairs.

    Samples come from :func:`npy_loader` and are typed as ``float32``.

    Args:
        maindir: Directory path prefix passed to ``npy_loader``; must end with
            a path separator (the pattern is ``maindir + '*.npy'``).
        shape_i: Static shape of one image sample, e.g. ``(1024, 832, 1)``.
        shape_m: Static shape of one mask sample.
        seed: Forwarded to ``npy_loader``.
        batch: Batch size. The final batch may be smaller.

    Returns:
        A ``tf.data.Dataset`` yielding ``(images, masks)`` batches.

    Raises:
        FileNotFoundError: If no file matches ``maindir + '*.npy'``.
    """
    pattern = maindir + '*.npy'
    # Fail fast. A wrong path would otherwise produce an empty dataset that
    # only surfaces later, e.g. as a missing val_loss during training.
    if not glob.glob(pattern):
        raise FileNotFoundError(f"no .npy files match {pattern!r}")
    # output_types/output_shapes are deprecated in newer TensorFlow in favour
    # of output_signature. They are kept here for compatibility with older
    # TF versions.
    ds = tf.data.Dataset.from_generator(
        lambda: npy_loader(maindir=maindir, seed=seed),
        output_types=(tf.float32, tf.float32),
        output_shapes=(shape_i, shape_m),
    )
    return ds.batch(batch)
# --- Data -------------------------------------------------------------------
# All splits share one root. Paths keep their trailing slash because
# npy_loader builds its glob pattern by string concatenation.
data_root = '/home/careinfolab/unet_mammo/images/pos_norm_elli/'
input_shape = (1024, 832, 1)  # shape of one image and of one mask sample
batch_size = 16
# Keras Model.fit ignores shuffle=True when given a tf.data.Dataset, so the
# training split is shuffled here, with a fresh order every epoch. Each
# buffered sample holds two float32 1024x832 maps (~6.8 MB), so a buffer of
# 64 samples costs roughly 440 MB of host RAM.
shuffle_buffer = 64

train_df = (
    npy_dataset(data_root + 'train/', input_shape, input_shape, seed=2, batch=1)
    .unbatch()  # back to single samples so they are shuffled individually
    .shuffle(shuffle_buffer, seed=2, reshuffle_each_iteration=True)
    .batch(batch_size)
)
val_df = npy_dataset(data_root + 'val/', input_shape, input_shape,
                     seed=2, batch=batch_size)
# NOTE(review): test_df is built but not used anywhere in this script.
test_df = npy_dataset(data_root + 'test/', input_shape, input_shape,
                      seed=2, batch=batch_size)

# --- Model ------------------------------------------------------------------
enc = sepresnet152v2_c.SepResNet152V2(input_shape, 1)
encoder = enc.build_graph()
encoder.load_weights('./saved_models/sepresnet152v2_enc_c2')
# To freeze the pretrained encoder, set `encoder.trainable = False` here,
# before the model is compiled.
decoder = dec.DeepUNet()
# NOTE(review): layers[-2] and layers[3] are presumably the encoder's deep
# feature output and an early high-resolution skip feature; confirm against
# DeepUNet.build_graph's expected arguments.
model = decoder.build_graph(encoder.input,
                            encoder.layers[-2].output,
                            encoder.layers[3].output)
model.compile(optimizer=Adam(learning_rate=0.0001), loss=[seg_loss],
              metrics=[dice_coef, iou_coef])

# --- Training ---------------------------------------------------------------
name = './saved_models/deep_unet_test152_elli'
# Save only when val_loss improves on the best value seen so far.
model_checkpoint = ModelCheckpoint(name, monitor='val_loss', save_best_only=True)
history = model.fit(train_df, epochs=150, verbose=1,
                    validation_data=val_df,
                    callbacks=[model_checkpoint])