-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathmain.py
164 lines (143 loc) · 5.74 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
#!/usr/bin/env python3
# -*- coding: UTF-8 -*-
import torch
import os
import argparse
import torch.optim as optim
from crnn import CRNN
from utils import *
def train(root, start_epoch, epoch_num, letters,
          net=None, lr=0.1, fix_width=True):
    """
    Train CRNN model

    Args:
        root (str): Root directory of dataset
        start_epoch (int): Epoch number to start; only used to label the
            per-epoch progress output
        epoch_num (int): Number of epochs to train
        letters (str): Letters contained in the data
        net (CRNN, optional): CRNN model to keep training; a new one is
            created when None (default: None)
        lr (float, optional): Coefficient that scales delta before it is
            applied to the parameters (default: 0.1)
        fix_width (bool, optional): Scale images to fixed size (default: True)

    Returns:
        CRNN: Trained CRNN model
    """
    # load data
    trainloader = load_data(root, training=True, fix_width=fix_width)
    if net is None:
        # create a new model: one output class per letter plus one extra
        # class (presumably the CTC blank)
        net = CRNN(1, len(letters) + 1)

    # use gpu or not. The model is moved to its final device BEFORE the
    # optimizer is constructed, as recommended by the torch.optim docs.
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    if not use_cuda:
        print("***** Warning: Cuda isn't available! *****")
    net = net.to(device)

    # NOTE(review): torch.nn.CTCLoss expects log-probabilities shaped
    # (T, N, C); this assumes CRNN.forward already applies log_softmax — confirm.
    criterion = torch.nn.CTCLoss().to(device)
    optimizer = optim.Adadelta(net.parameters(), lr=lr, weight_decay=1e-3)

    # get encoder and decoder
    labeltransformer = LabelTransformer(letters)

    print('==== Training.. ====')
    # .train() switches Dropout and BatchNorm to their training behaviour.
    net.train()
    for epoch in range(start_epoch, start_epoch + epoch_num):
        print('---- epoch: %d ----' % (epoch, ))
        loss_sum = 0
        for img, label in trainloader:
            label, label_length = labeltransformer.encode(label)
            img = img.to(device)
            optimizer.zero_grad()
            # put images in; outputs is indexed as (length, batch, ...) below
            outputs = net(img)
            # every sample in the batch uses the full output sequence length
            output_length = torch.IntTensor(
                [outputs.size(0)] * outputs.size(1))
            # calc loss
            loss = criterion(outputs, label, output_length, label_length)
            # update
            loss.backward()
            optimizer.step()
            loss_sum += loss.item()
        # loss summed (not averaged) over all batches of this epoch
        print('loss = %f' % (loss_sum, ))
    print('Finished Training')
    return net
def test(root, net, letters, fix_width=True):
    """
    Test CRNN model on the test split, then on the training split.

    Args:
        root (str): Root directory of dataset
        net (CRNN): trained CRNN model
        letters (str): Letters contained in the data
        fix_width (bool, optional): Scale images to fixed size (default: True)

    Returns:
        tuple: (testing accuracy, training accuracy), both in percent.
            An empty loader yields 0.0 for its accuracy.
    """
    # load data
    trainloader = load_data(root, training=True, fix_width=fix_width)
    testloader = load_data(root, training=False, fix_width=fix_width)
    # use gpu or not
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    if not use_cuda:
        print("***** Warning: Cuda isn't available! *****")
    net = net.to(device)
    # get encoder and decoder
    labeltransformer = LabelTransformer(letters)
    print('==== Testing.. ====')
    # .eval() switches Dropout and BatchNorm to their inference behaviour.
    net.eval()
    acc = []
    # Evaluation never calls backward(), so don't build autograd graphs.
    with torch.no_grad():
        for loader in (testloader, trainloader):
            correct = 0
            total = 0
            for img, origin_label in loader:
                img = img.to(device)
                outputs = net(img)  # length × batch × num_letters
                # most likely class at each step, then batch-major for decoding
                outputs = outputs.max(2)[1].transpose(0, 1)  # batch × length
                outputs = labeltransformer.decode(outputs)
                # a sample counts only if the whole decoded label matches
                correct += sum(out == real
                               for out, real in zip(outputs, origin_label))
                total += len(origin_label)
            # calc accuracy; guard against an empty loader
            acc.append(correct / total * 100 if total else 0.0)
    print('testing accuracy: ', acc[0], '%')
    print('training accuracy: ', acc[1], '%')
    return acc[0], acc[1]
def main(epoch_num, lr=0.1, training=True, fix_width=True):
    """
    Main: train (then test) or only test a CRNN on IIIT5K.

    Args:
        epoch_num (int): Number of epochs to train (ignored when only testing)
        lr (float, optional): Learning rate passed to train() (default: 0.1)
        training (bool, optional): If True, train the model, otherwise test it (default: True)
        fix_width (bool, optional): Scale images to fixed size (default: True);
            also selects which checkpoint file is used
    """
    model_path = ('fix_width_' if fix_width else '') + 'crnn.pth'
    letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789'
    root = 'data/IIIT5K/'

    # one output class per letter plus one extra (presumably the CTC blank)
    net = CRNN(1, len(letters) + 1)
    # if there is pre-trained model, load it
    if os.path.exists(model_path):
        print('Pre-trained model detected.\nLoading model...')
        # map_location='cpu': a checkpoint saved after GPU training holds CUDA
        # tensors, which cannot be deserialized on a CPU-only machine without
        # remapping. train()/test() move the model to the right device later.
        net.load_state_dict(torch.load(model_path, map_location='cpu'))
    elif not training:
        print('***** Warning: no pre-trained model found at %s; '
              'testing an untrained model! *****' % (model_path, ))

    if training:
        start_epoch = 0
        if torch.cuda.is_available():
            print('GPU detected.')
        net = train(root, start_epoch, epoch_num, letters,
                    net=net, lr=lr, fix_width=fix_width)
        # save the trained model for training again
        torch.save(net.state_dict(), model_path)
    # test
    test(root, net, letters, fix_width=fix_width)
def _build_parser():
    """Build the command-line parser for this script."""
    parser = argparse.ArgumentParser()
    # %(default)s keeps the help text in sync with the actual default values
    parser.add_argument('--epoch_num', type=int, default=50,
                        help='number of epochs to train for (default=%(default)s)')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate for optim (default=%(default)s)')
    parser.add_argument('--test', action='store_true',
                        help='Whether to test directly (default is training)')
    parser.add_argument('--fix_width', action='store_true',
                        help='Whether to resize images to the fixed width (default is False)')
    return parser


if __name__ == '__main__':
    opt = _build_parser().parse_args()
    print(opt)
    main(opt.epoch_num, lr=opt.lr, training=(not opt.test), fix_width=opt.fix_width)