-
Notifications
You must be signed in to change notification settings - Fork 17
/
Copy pathresnet.py
129 lines (92 loc) · 4.11 KB
/
resnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import argparse
from keras.layers import (
Conv2D, BatchNormalization,
MaxPooling2D, ZeroPadding2D, AveragePooling2D,
add, Dense, Flatten
)
from keras.layers.advanced_activations import PReLU
from model import BaseModel
from utils import load_mnist
def resnet(input_tensor, num_classes=10):
    """Inference function for a ResNet-50-style network (stages 2-4).

    y = resnet(X)

    Layers follow the Keras ResNet50 naming scheme
    (``res{stage}{block}_branch{name}`` / ``bn{stage}{block}_branch{name}``),
    but use PReLU activations and stop after stage 4. The network ends with a
    2x2 average pool, a flatten and a dense softmax classifier.

    Parameters
    ----------
    input_tensor : keras.layers.Input
        Image input tensor. NOTE(review): the final ``AveragePooling2D((2, 2))``
        appears sized for a 2x2 stage-4 feature map (e.g. 28x28 MNIST input);
        other input sizes may need a different pool size -- confirm.
    num_classes : int, optional
        Number of units in the softmax output layer (default 10).

    Returns
    -------
    y : keras tensor
        Softmax output with ``num_classes`` units.
    """
    def name_builder(kind, stage, block, name):
        # e.g. ("res", 2, "a", "2b") -> "res2a_branch2b"
        return "{}{}{}_branch{}".format(kind, stage, block, name)

    def identity_block(input_tensor, kernel_size, filters, stage, block):
        """Bottleneck block whose shortcut is the unchanged input.

        The residual branch ends with ``filters[2]`` channels, so the input
        must already have that many channels for the ``add`` to succeed.
        """
        F1, F2, F3 = filters

        def name_fn(kind, name):
            return name_builder(kind, stage, block, name)

        x = Conv2D(F1, (1, 1), name=name_fn('res', '2a'))(input_tensor)
        x = BatchNormalization(name=name_fn('bn', '2a'))(x)
        x = PReLU()(x)

        x = Conv2D(F2, kernel_size, padding='same', name=name_fn('res', '2b'))(x)
        x = BatchNormalization(name=name_fn('bn', '2b'))(x)
        x = PReLU()(x)

        x = Conv2D(F3, (1, 1), name=name_fn('res', '2c'))(x)
        x = BatchNormalization(name=name_fn('bn', '2c'))(x)
        # No activation before the add (same as conv_block): the residual sum
        # is activated exactly once, below. Activating here too would apply a
        # second non-linearity to the residual branch only.

        x = add([x, input_tensor])
        x = PReLU()(x)
        return x

    def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
        """Bottleneck block with a projection (1x1 conv + BN) shortcut.

        ``strides`` is applied to both the first 1x1 conv and the shortcut conv
        so the two branches end up with matching spatial size.
        """
        F1, F2, F3 = filters

        def name_fn(kind, name):
            return name_builder(kind, stage, block, name)

        x = Conv2D(F1, (1, 1), strides=strides, name=name_fn("res", "2a"))(input_tensor)
        x = BatchNormalization(name=name_fn("bn", "2a"))(x)
        x = PReLU()(x)

        x = Conv2D(F2, kernel_size, padding='same', name=name_fn("res", "2b"))(x)
        x = BatchNormalization(name=name_fn("bn", "2b"))(x)
        x = PReLU()(x)

        x = Conv2D(F3, (1, 1), name=name_fn("res", "2c"))(x)
        x = BatchNormalization(name=name_fn("bn", "2c"))(x)

        sc = Conv2D(F3, (1, 1), strides=strides, name=name_fn("res", "1"))(input_tensor)
        sc = BatchNormalization(name=name_fn("bn", "1"))(sc)

        x = add([x, sc])
        x = PReLU()(x)
        return x

    # Stem: zero-pad, 7x7/2 conv, BN, PReLU, 3x3/2 max pool.
    net = ZeroPadding2D((3, 3))(input_tensor)
    net = Conv2D(64, (7, 7), strides=(2, 2), name="conv1")(net)
    net = BatchNormalization(name="bn_conv1")(net)
    net = PReLU()(net)
    net = MaxPooling2D((3, 3), strides=(2, 2))(net)

    # (stage, bottleneck filters, number of blocks, strides of block 'a').
    # Block 'a' is a projection conv_block; the rest are identity blocks
    # lettered 'b', 'c', ... exactly as in the ResNet50 naming scheme.
    stages = (
        (2, [64, 64, 256], 3, (1, 1)),
        (3, [128, 128, 512], 4, (2, 2)),
        (4, [256, 256, 1024], 6, (2, 2)),
    )
    for stage, filters, n_blocks, strides in stages:
        net = conv_block(net, 3, filters, stage=stage, block='a', strides=strides)
        for block in 'bcdef'[:n_blocks - 1]:
            net = identity_block(net, 3, filters, stage=stage, block=block)

    # Classifier head.
    net = AveragePooling2D((2, 2))(net)
    net = Flatten()(net)
    net = Dense(num_classes, activation="softmax", name="softmax")(net)
    return net
class ResNet50(BaseModel):
    """BaseModel wrapper that trains/serves the ``resnet`` inference function.

    Parameters
    ----------
    model_path : str
        Path forwarded to BaseModel (the command-line default is
        ``model/resnet.h5``).
    """

    def __init__(self, model_path):
        model_name = "resnet"
        inference_fn = resnet
        super(ResNet50, self).__init__(model_name, inference_fn, model_path)
def arg_parser(argv=None):
    """Parse the training command-line arguments.

    Parameters
    ----------
    argv : list of str, optional
        Arguments to parse. ``None`` (the default) makes argparse read
        ``sys.argv[1:]``; passing a list allows calling this from code/tests.

    Returns
    -------
    tuple of (int, str)
        The number of epochs and the model path.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("epoch", type=int, help="Epochs")
    parser.add_argument(
        "--model_path",
        default="model/resnet.h5",
        type=str,
        # %(default)s is expanded by argparse, so the help text cannot drift
        # from the actual default value.
        help="model path (default: %(default)s)",
    )
    args = parser.parse_args(argv)
    return args.epoch, args.model_path
def main():
    """Train ResNet50 on normalized MNIST using the epoch count and model path from the CLI."""
    epochs, model_path = arg_parser()
    train_split, valid_split, _ = load_mnist(samplewise_normalize=True)
    model = ResNet50(model_path)
    # Only the first two entries of each split are forwarded to fit
    # (presumably inputs and labels -- confirm against load_mnist).
    train_data = (train_split[0], train_split[1])
    valid_data = (valid_split[0], valid_split[1])
    model.fit(train_data, valid_data, epochs)
# Run training only when this file is executed as a script, not when imported.
if __name__ == '__main__':
    main()