-
Notifications
You must be signed in to change notification settings - Fork 6
/
utils.py
141 lines (112 loc) · 4.26 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
import os
import pickle
import string
import numpy as np
try:
import requests
has_requests = True
except ImportError:
has_requests = False
# Transcriptions that are counted as a detection of the keyword when
# computing accuracy stats (the keyword itself plus near misses).
allowed_text = ['loha', 'alha', 'aloa', 'aloh', 'aoha', 'aloha']

# Index -> character lookup table: a-z, apostrophe, space, and finally '-',
# which merge() treats as the blank CTC symbol.
id_to_char = np.array(list(string.ascii_lowercase + "' -"))
def merge(chars):
    '''Collapse runs of repeated symbols, then drop the blank CTC symbol.

    `chars` is any iterable of symbols (typically a string). Consecutive
    duplicates are reduced to a single occurrence, after which every "-"
    (the blank) is removed. Returns the remaining symbols as one string.
    '''
    kept = []
    # Start from the blank so that leading blanks count as repeats.
    previous = "-"
    for symbol in chars:
        if symbol == previous:
            continue
        previous = symbol
        # Blanks still separate runs (via `previous`) but are never emitted.
        if symbol != "-":
            kept.append(symbol)
    return "".join(kept)
def weight_init(shape):
    '''Return an array of the given shape with small random weights.

    Values are drawn uniformly from the interval [-0.05, 0.05) using
    NumPy's global random state.
    '''
    return np.random.uniform(low=-0.05, high=0.05, size=shape)
def predict_text(sim, char_probe, n_steps, p_time):
    '''Predict a text transcription from the current simulation state.

    `sim.data[char_probe]` is read as a 2-D array with one row per
    simulation step and one column per character in `id_to_char`
    (presumably a Nengo simulator and probe -- confirm against the caller).
    `n_steps` is the total number of steps and `p_time` the number of steps
    each window frame was presented for; `n_steps` must be a multiple of
    `p_time` for the reshape to succeed.

    Returns the merged transcription as a string.
    '''
    frame_count = int(n_steps / p_time)
    probe_output = sim.data[char_probe]
    alphabet_size = probe_output.shape[1]

    # Separate out each window frame that was presented.
    per_frame = np.reshape(probe_output, (frame_count, p_time, alphabet_size))

    # Within each frame, take the most often predicted character over the
    # presentation interval.
    step_winners = np.argmax(per_frame, axis=2)
    frame_chars = []
    for frame in step_winners:
        votes = np.bincount(frame)
        frame_chars.append(id_to_char[np.argmax(votes)])

    transcription = merge(''.join(frame_chars))
    # Merge a second time: removing blanks can leave new adjacent repeats,
    # and collapsing them helps autocorrect.
    return merge(transcription)
def download(fname, drive_id):
    '''Download a file from Google Drive.

    Adapted from https://stackoverflow.com/a/39225039/1306923

    `fname` is the local path the content is written to and `drive_id` is
    the Google Drive file id. If the first response carries a cookie whose
    name starts with 'download_warning', its value is sent back as the
    `confirm` parameter in a second request.

    Raises `requests.HTTPError` if the final response has an error status,
    so that an error page is never written to `fname`. Requires the
    `requests` package.
    '''
    def get_confirm_token(response):
        # Return the confirmation token from the response cookies, if any.
        for key, value in response.cookies.items():
            if key.startswith('download_warning'):
                return value
        return None

    def save_response_content(response, destination):
        # Stream the body to disk in chunks so it is never held in memory.
        CHUNK_SIZE = 32768
        with open(destination, "wb") as f:
            for chunk in response.iter_content(CHUNK_SIZE):
                if chunk:  # filter out keep-alive new chunks
                    f.write(chunk)

    url = "https://docs.google.com/uc?export=download"

    # The context manager closes the session (and its pooled connections)
    # even when the download fails part-way.
    with requests.Session() as session:
        response = session.get(url, params={'id': drive_id}, stream=True)
        token = get_confirm_token(response)

        if token is not None:
            # Release the unread streamed response before replacing it.
            response.close()
            params = {'id': drive_id, 'confirm': token}
            response = session.get(url, params=params, stream=True)

        # Fail before opening the destination file on an HTTP error.
        response.raise_for_status()
        save_response_content(response, fname)
def load(fname, drive_id):
    '''Load a pickled file, downloading it first if it is not present.

    If `fname` does not exist it is fetched from Google Drive using
    `drive_id` when the `requests` package is available; otherwise a
    RuntimeError with manual download instructions is raised. Returns the
    unpickled object.
    '''
    if not os.path.exists(fname):
        if not has_requests:
            link = "https://drive.google.com/open?id=%s" % drive_id
            raise RuntimeError(
                "Cannot find '%s'. Download the file from\n %s\n"
                "and place it in %s." % (fname, link, os.getcwd()))

        print("Downloading %s..." % fname)
        download(fname, drive_id)
        print("Saved %s to %s" % (fname, os.getcwd()))

    print("Loading %s" % fname)
    # NOTE(review): unpickling can execute arbitrary code; only use this
    # with files from a trusted source.
    with open(fname, "rb") as fp:
        return pickle.load(fp)
def compute_tf_stats(model, data):
    '''Compute True/False Pos/Neg stats for Tensorflow keyword model.

    `model` must provide `predict_text(window)` returning a string for one
    window (passed with a leading axis of length 1) and `merge(text)`.
    `data` is an iterable of `(features, text)` pairs, where `text` is the
    reference transcription; items labelled 'aloha' are the positive class.
    A prediction counts as a detection when the merged output is one of
    the transcriptions in `allowed_text`.

    Prints a summary of the four rates and returns the dict of raw counts
    (keys "tp", "fp", "tn", "fn", "aloha", "not-aloha"). A rate whose class
    has no examples is reported as NaN instead of raising
    ZeroDivisionError.
    '''
    def rate(count, total):
        # A class with no examples has no defined rate.
        return count / total if total else float("nan")

    stats = {
        "fp": 0,
        "tp": 0,
        "fn": 0,
        "tn": 0,
        "aloha": 0,
        "not-aloha": 0,
    }

    for features, text in data:
        inputs = np.squeeze(features)
        # Predict one character per window.
        chars = [
            model.predict_text(np.expand_dims(window, axis=0))
            for window in inputs
        ]
        # Merge twice: removing blanks can leave new adjacent repeats.
        predicted_chars = model.merge(model.merge(''.join(chars)))
        detected = predicted_chars in allowed_text

        if text == 'aloha':
            stats["aloha"] += 1
            stats["tp" if detected else "fn"] += 1
        else:
            stats["not-aloha"] += 1
            stats["fp" if detected else "tn"] += 1

    positives = stats["aloha"]
    negatives = stats["not-aloha"]

    print("Summary")
    print("=======")
    print("True positive rate:\t%.3f" % rate(stats["tp"], positives))
    print("False negative rate:\t%.3f" % rate(stats["fn"], positives))
    print()
    print("True negative rate:\t%.3f" % rate(stats["tn"], negatives))
    print("False positive rate:\t%.3f" % rate(stats["fp"], negatives))

    return stats