# ctc_predict.py
# Decode a music score image with a trained CTC model and print the
# predicted symbol sequence.
import argparse
import tensorflow as tf
import ctc_utils
import cv2
import numpy as np
# Command-line interface: every option is a mandatory path string.
parser = argparse.ArgumentParser(
    description='Decode a music score image with a trained model (CTC).')
for flag, destination, help_text in (
        ('-image', 'image', 'Path to the input image.'),
        ('-model', 'model', 'Path to the trained model.'),
        ('-vocabulary', 'voc_file', 'Path to the vocabulary file.'),
):
    parser.add_argument(flag, dest=destination, type=str, required=True,
                        help=help_text)
args = parser.parse_args()
# Clear any existing default graph before the saved graph is imported, then
# open an interactive session (it installs itself as the default session).
tf.reset_default_graph()
sess = tf.InteractiveSession()

# Read the dictionary: line i of the vocabulary file is the token for index i.
# The context manager guarantees the file is closed even if reading fails.
with open(args.voc_file, 'r') as dict_file:
    int2word = dict(enumerate(dict_file.read().splitlines()))
# Restore weights. args.model is expected to be the ".meta" graph file; the
# checkpoint prefix passed to restore() is that path without the ".meta" suffix.
saver = tf.train.import_meta_graph(args.model)
meta_suffix = '.meta'
if args.model.endswith(meta_suffix):
    checkpoint_prefix = args.model[:-len(meta_suffix)]
else:
    checkpoint_prefix = args.model
saver.restore(sess, checkpoint_prefix)
graph = tf.get_default_graph()

# Look up the tensors by the names they were given when the model was built.
# (Named model_input rather than input to avoid shadowing the builtin.)
model_input = graph.get_tensor_by_name("model_input:0")
seq_len = graph.get_tensor_by_name("seq_lengths:0")
rnn_keep_prob = graph.get_tensor_by_name("keep_prob:0")
height_tensor = graph.get_tensor_by_name("input_height:0")
width_reduction_tensor = graph.get_tensor_by_name("width_reduction:0")
logits = tf.get_collection("logits")[0]

# Constants that are saved inside the model itself
WIDTH_REDUCTION, HEIGHT = sess.run([width_reduction_tensor, height_tensor])

decoded, _ = tf.nn.ctc_greedy_decoder(logits, seq_len)

# Load the score as a single-channel (grayscale) image.
image = cv2.imread(args.image, cv2.IMREAD_GRAYSCALE)
if image is None:
    # cv2.imread reports an unreadable/missing file by returning None.
    raise SystemExit("Could not read image: %s" % args.image)
image = ctc_utils.resize(image, HEIGHT)
image = ctc_utils.normalize(image)
# Add a leading batch axis and a trailing channel axis: (1, rows, cols, 1).
image = np.asarray(image).reshape(1, image.shape[0], image.shape[1], 1)

# One sequence length per batch item: input width divided by the model's width
# reduction. Floor division keeps it an integer (plain "/" yields a float on
# Python 3), matching the original Python 2 semantics.
seq_lengths = [int(image.shape[2] // WIDTH_REDUCTION)]

prediction = sess.run(decoded,
                      feed_dict={
                          model_input: image,
                          seq_len: seq_lengths,
                          rnn_keep_prob: 1.0,  # keep_prob=1.0: no dropout at inference
                      })
# Convert the sparse decoder output into per-sequence lists of vocabulary
# indices (keys of int2word), then print the tokens for the single input image
# on one tab-separated line. A single print() call behaves the same on Python 2
# and 3, unlike the Python 2 trailing-comma print idiom, which on Python 3 put
# every token and every tab on its own line.
str_predictions = ctc_utils.sparse_tensor_to_strs(prediction)
print('\t'.join(int2word[w] for w in str_predictions[0]))