-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
348 lines (323 loc) · 13.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
# -*- coding:utf-8 -*-
import os
import cv2
import sys
import time
import copy
import random
import thread
import numpy as np
from mnist import mnist
import deepwell as dp
from PIL import Image
# Print numpy arrays in full instead of summarizing them.  Recent numpy
# versions reject threshold=np.nan with a ValueError; sys.maxsize is the
# documented way to disable truncation.
np.set_printoptions(threshold=sys.maxsize)
# Run from the script's directory so relative paths (./tdata/, weight files)
# resolve independently of the caller's cwd.  sys.path[0] is the empty string
# when there is no script (interactive / stdin); stay in the cwd then.
if sys.path[0]:
    os.chdir(sys.path[0])
# One-hot encodings of the digits 0-9: lable_to_vector[d] has 1.0 at index d
# and 0.0 elsewhere (a list of 10 lists of Python floats).
lable_to_vector = np.eye(10).tolist()
width = 640  # camera frame width
hight = 480  # camera frame height
# Shared process state: 0 idle, -1 exit, 3 collecting a sample (save the
# feature and retrain).
status = 0
num_labels = []  # latest prediction result
input_label = -127  # label entered for collection; -127 means "none entered"
timg = np.zeros((hight, width), np.uint8)  # fingertip-trace debug image
org_img = np.zeros((hight, width), np.uint8)  # raw camera frame
# Display of the processed trace; also the enlarged version of the image
# that gets recognized.
mnist_img = np.zeros((256, 256))
# Start-up logo
cv2.putText(mnist_img, "WestWell", (0, 64), cv2.FONT_HERSHEY_DUPLEX, 1.8, (255,255,255), 3, 3)
# Load the collected trace samples and their labels.
def get_trace_data(data_path):
    """Load every ``.npy`` trace sample in *data_path* with its one-hot label.

    The digit label is the first ``_``-separated field of each file name
    (e.g. ``3_xxx.npy`` -> 3).  Each array is flattened into one row.

    Files are read in sorted name order so that the fixed-seed shuffle below
    gives the same ordering on every filesystem (``os.listdir`` order is
    arbitrary).  The global numpy RNG is reseeded with 701507, as before.

    Args:
        data_path: directory containing the samples; a trailing separator
            is optional.

    Returns:
        ``(data, labels)``: 2-D arrays whose rows are the flattened samples
        and the matching one-hot label vectors, shuffled identically.

    Raises:
        ValueError: if the directory contains no ``.npy`` files.
    """
    data_out = []
    label_out = []
    for name in sorted(os.listdir(data_path)):
        if not name.endswith(".npy"):
            continue  # only files written by np.save are samples
        data_out.append(np.load(os.path.join(data_path, name)).flatten())
        label = int(name.split("_")[0])  # leading field is the label
        label_out.append(lable_to_vector[label])
    if not data_out:
        raise ValueError("no .npy trace samples found in %r" % (data_path,))
    data_out_mtx = np.vstack(data_out)
    label_out_mtx = np.vstack(label_out)
    # Fixed seed for a reproducible order.  The extra shuffle of the
    # permutation is kept so the resulting order matches earlier runs.
    np.random.seed(701507)
    num_sum = data_out_mtx.shape[0]
    shuffle_idx = np.random.permutation(num_sum)
    np.random.shuffle(shuffle_idx)
    return data_out_mtx[shuffle_idx, :], label_out_mtx[shuffle_idx, :]
# Load MNIST samples and their labels.
def get_mnist_data(num):
    """Return the first *num* MNIST training images with one-hot labels.

    Only the "train" split is read.  Images and labels are stacked row-wise
    and then shuffled with the same permutation.  The shuffle reseeds the
    global numpy RNG with 701507.

    Returns:
        ``(data, labels)`` as two 2-D arrays with *num* rows each.
    """
    train_set = mnist("train")  # only the training split is used
    images, one_hots = [], []
    for idx in range(num):
        digit, picture = train_set.GetImage(idx)
        images.append(picture)
        one_hots.append(lable_to_vector[int(digit)])
    image_mtx = np.vstack(images)
    onehot_mtx = np.vstack(one_hots)
    # Fixed seed (a large prime) so the shuffle is reproducible.
    np.random.seed(701507)
    order = np.random.permutation(num)
    np.random.shuffle(order)
    return image_mtx[order, :], onehot_mtx[order, :]
def pic_mnist(arry_img):
    """Turn a fingertip-trace image into a 28x28 MNIST-style binary digit.

    Steps:
    1. Fill every contour with more than 60 points; smaller ones are noise.
    2. Intersect the result with a dilated copy of the trace.
    3. Crop the bounding box of what remains.
    4. Zero-pad the crop to a centred square and resize it to 28x28.

    Side effect: sets the global ``mnist_img`` to a 256x256 rendering of the
    padded crop, used for display.

    Args:
        arry_img: single-channel trace image; pixels > 1 are treated as ink.

    Returns:
        A 28x28 ``uint8`` array with values 0/255, or ``None`` when no
        contour is found or fewer than 80 ink pixels survive the filtering.
    """
    global mnist_img
    # 5x5 cross-shaped structuring element for the dilation below.
    kernel = np.zeros((5, 5), np.uint8)
    kernel[2, :] = 1
    kernel[:, 2] = 1
    # findContours needs a binary uint8 image, so binarize again defensively.
    retval, arry_img = cv2.threshold(arry_img, 1, 255, cv2.THRESH_BINARY)
    arry_img = arry_img.astype(np.uint8)
    # OpenCV 2.4 and 4.x return (contours, hierarchy); 3.x prepends the
    # image.  Taking the last two items works on all of them.
    # NOTE(review): OpenCV < 3.2 modifies its input here, so on those
    # versions the dilation below sees the modified image.  Passing
    # arry_img.copy() would make the features version-independent, but it
    # would also change them relative to previously collected data.
    contours, hierarchy = cv2.findContours(
        arry_img, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)[-2:]
    if len(contours) == 0:
        return None
    back = np.zeros_like(arry_img, np.uint8)
    # Fill (in white on black) every contour with more than 60 points.
    # Draw order is irrelevant because all fills are white.
    for idx, contour in enumerate(contours):
        if len(contour) > 60:
            cv2.drawContours(back, contours, idx, (255, 255, 255), -1)
    dilate = cv2.dilate(arry_img, kernel, iterations=2)
    back = cv2.bitwise_and(back, dilate)
    rows, cols = np.nonzero(back)
    if rows.size < 80:  # fewer than 80 ink pixels: not a usable digit
        return None
    cut_img = back[rows.min():rows.max() + 1, cols.min():cols.max() + 1]
    # Zero-pad the shorter side so the crop becomes a centred square; an odd
    # leftover pixel goes to the bottom / right.  Floor division keeps the
    # pad sizes integral on Python 3 as well.  The float64 zero blocks make
    # pad_img float64, which the INTER_AREA resizes below operate on.
    h, w = cut_img.shape
    if w >= h:
        top = (w - h) // 2
        bottom = top + (w - h) % 2
        pad_img = np.vstack((np.zeros([top, w]), cut_img, np.zeros([bottom, w])))
    else:
        left = (h - w) // 2
        right = left + (h - w) % 2
        pad_img = np.hstack((np.zeros([h, left]), cut_img, np.zeros([h, right])))
    digit = cv2.resize(pad_img, (28, 28), interpolation=cv2.INTER_AREA)
    digit = digit.astype(np.uint8)
    retval, digit = cv2.threshold(digit, 1, 255, cv2.THRESH_BINARY)
    digit = digit.astype(np.uint8)
    mnist_img = cv2.resize(pad_img, (256, 256), interpolation=cv2.INTER_AREA)
    return digit
# Initialize deepwell and set its parameters.
# l_scale, d_scale, y_scale and h_size are passed unchanged to every
# dp.Train / dp.Test call in this file; their meaning is defined by the
# deepwell library, which is not visible here (TODO: document from deepwell).
l_scale = 6
d_scale = 3
y_scale = 0
h_size = 0x7ff#0x800 * 8 = 16K
dp.INIT()
dp.WD_EN(1)
# Last argument of the initial dp.Train call below; the incremental retrain
# in detect_dpw passes 0 instead.  Presumably 1 = start from a cleared state
# and 0 = keep the existing state -- confirm against deepwell.
clear_M_en = 1
# Disabled code: the triple-quoted string below is evaluated and discarded.
# It holds an MNIST benchmark (train on 10000 images, then test on 1000 of
# them) that can be re-enabled by removing the quotes.
'''
#先使用deepwell 训练和测试一遍mnist
datas, labels = get_mnist_data(10000)
org_labels = [x.argmax() for x in labels]
org_ar = np.array(org_labels)
tdatas = datas[0:1000,]
torg_ar = org_ar[0:1000,]
s = time.time()
dp.Train(datas, labels, l_scale, d_scale, y_scale, h_size, clear_M_en)
e = time.time()
print "deepwell train 10000 using time:%f"%(e - s)
dp.wait_for_idle(0)
s = time.time()
res_labels = dp.Test(tdatas, l_scale, d_scale, y_scale, h_size)
e = time.time()
print "deepwell test 1000 using time:%f"%(e - s)
res_ar = np.array(res_labels)
cmp_ar = torg_ar - res_ar
acc = float(np.sum(cmp_ar==0))/len(torg_ar)
print acc
'''
# Then train and test deepwell once on this project's collected trace images
# (original note: images trained later carry more weight).
# The two "#'''" marker lines below are toggles: deleting their "#" turns the
# train/test/save section between them into a disabled string literal.
t_datas, t_labels = get_trace_data("./tdata/")
#'''
# Ground-truth digit labels recovered from the one-hot label rows.
t_org_labels = [x.argmax() for x in t_labels]
t_org_ar = np.array(t_org_labels)
# Redundant: clear_M_en is already 1 from above.
clear_M_en = 1
s = time.time()
dp.Train(t_datas, t_labels, l_scale, d_scale, y_scale, h_size, clear_M_en)
e = time.time()
print "deepwell train trace using time:%f"%(e - s)
dp.wait_for_idle(0)
s = time.time()
res_labels = dp.Test(t_datas, l_scale, d_scale, y_scale, h_size)
e = time.time()
print "deepwell test trace using time:%f"%(e - s)
# Accuracy = fraction of samples whose predicted label equals the true one.
# dp.Test runs on the same t_datas used for dp.Train, so this is
# training-set accuracy, not a held-out estimate.
res_ar = np.array(res_labels)
cmp_ar = t_org_ar - res_ar
acc = float(np.sum(cmp_ar==0))/len(t_org_ar)
print acc
dp.save_weights(0,"./base_weight")#verify the weight save/load round-trip
dp.load_weights(0,"./base_weight")
#'''
# Alternative start-up: load the weights saved by the live retraining loop
# (detect_dpw writes ./current_weight).
#dp.load_weights(0, "./current_weight")
# Worker thread: track the fingertip trace and recognize it.
def detect_dpw():
    """Track a skin-coloured fingertip in the camera feed and classify the trace.

    Runs until the shared global ``status`` becomes -1.

    Reads ``status`` and ``input_label``, which the UI thread sets on key
    presses.

    Publishes these globals:
    * ``org_img``: latest frame with the capture box drawn.
    * ``timg``: the accumulated trace.
    * ``mnist_img``: set via ``pic_mnist``.
    * ``num_labels``: the latest prediction.

    While ``status == 3`` (collection mode) tracking is paused.  Once a
    digit label has been entered, the most recent feature is appended to
    ``t_datas``/``t_labels`` and deepwell is retrained until it predicts that
    label for the feature.  The weights are then saved to
    ``./current_weight`` and the thread returns to idle.  The camera is
    released when the thread exits.
    """
    global status, org_img, timg, mnist_img, num_labels, width, hight
    global input_label, t_datas, t_labels
    # Feature vector (1 x 784) used for both recognition and retraining.
    mnist_feature = np.zeros((28, 28), np.uint8).reshape((1, -1))
    trace_status = 0  # 1 while a stroke is being traced
    hit_contours = []  # recent fingertip positions; (-1, -1) records a miss
    h = 346  # capture region height
    w = 404  # capture region width
    # Capture box centred in the frame.  Floor division keeps the slice
    # indices integers on Python 3 (same values as before on Python 2).
    x = width // 2 - w // 2
    y = hight // 2 - h // 2
    kernel_ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (8, 8))
    cap = cv2.VideoCapture(1)
    try:
        while True:
            if status == -1:  # exit requested
                return
            if status == 3:  # collecting data: no tracking
                if input_label == -127:
                    # Still waiting for a digit key: yield the CPU instead
                    # of spinning.
                    time.sleep(0.01)
                    continue
                # Add the latest feature under the entered label and retrain.
                # Until the label is learned this repeats on every pass,
                # appending the same sample again each time.
                ar = np.array(lable_to_vector[input_label]).reshape((1, -1))
                t_datas = np.vstack((t_datas, mnist_feature))
                t_labels = np.vstack((t_labels, ar))
                np.random.seed(701507)
                shuffle_idx = np.random.permutation(t_datas.shape[0])
                np.random.shuffle(shuffle_idx)
                t_datas = t_datas[shuffle_idx, :]
                t_labels = t_labels[shuffle_idx, :]
                dp.Train(t_datas, t_labels, l_scale, d_scale, y_scale, h_size, 0)
                dp.wait_for_idle(0)
                pre_label = dp.Test(mnist_feature, l_scale, d_scale, y_scale, h_size)
                print(pre_label)
                if pre_label[0] == input_label:
                    # Learned: save the weights, leave collection mode and
                    # display the new label.
                    print("learn success")
                    dp.save_weights(0, "./current_weight")
                    input_label = -127
                    status = 0
                    num_labels = pre_label
                continue
            ret, frame = cap.read()
            if not ret or frame is None:
                # Grab failed: keep the last good frame for the UI thread
                # instead of publishing None, and retry shortly.
                time.sleep(0.01)
                continue
            org_img = frame
            img_crop = org_img[y:y + h, x:x + w, :]
            cv2.rectangle(org_img, (x, y), (x + w, y + h), (55, 255, 155), 3)
            hsv = cv2.cvtColor(img_crop, cv2.COLOR_BGR2HSV)
            # Skin-colour mask in HSV space, eroded, then binarized.
            mask2 = cv2.inRange(hsv, np.array([2, 50, 50]), np.array([15, 255, 255]))
            erosion = cv2.erode(mask2, kernel_ellipse, iterations=1)
            retval, binary = cv2.threshold(erosion, 15, 255, cv2.THRESH_BINARY)
            nz = np.nonzero(binary)
            if len(nz[0]) != 0:
                # The first skin pixel in row-major order (topmost row) is
                # taken as the fingertip.
                ty = int(nz[0][0])
                tx = int(nz[1][0])
                hit_contours.append((tx, ty))
                if hit_contours.count((tx, ty)) > 3:
                    # Fingertip seen at one spot more than 3 times: start.
                    tnz = np.nonzero(timg)
                    if len(tnz[0]) < 10:
                        # Only wipe an (almost) empty trace, so that a long
                        # pause mid-stroke does not erase what was drawn.
                        timg = np.zeros((hight, width), np.uint8)
                    hit_contours = []
                    trace_status = 1
                    mnist_img = np.zeros((256, 256))  # clear displayed trace
                    num_labels = []  # clear the previous prediction
                if trace_status == 1:
                    cv2.circle(timg, (tx, ty), 6, (255, 255, 255), -1)
            else:
                # No fingertip: more than 20 misses ends the stroke, so
                # recognize it.
                hit_contours.append((-1, -1))
                if hit_contours.count((-1, -1)) > 20:
                    mnist_28x28 = pic_mnist(timg)
                    # `is not None`: `ndarray != None` compares element-wise,
                    # and using that array in an `if` raises ValueError.
                    if mnist_28x28 is not None:
                        mnist_feature = mnist_28x28.reshape((1, -1))
                        num_labels = dp.Test(mnist_feature, l_scale, d_scale, y_scale, h_size)
                        print(num_labels)  # recognition result
                    timg = np.zeros((hight, width), np.uint8)
                    hit_contours = []
                    trace_status = 0
    finally:
        cap.release()
# UI display: the black full-screen canvas that the UI thread draws on.
displayer_w, displayer_h = 1920, 1080  # monitor resolution
back_array = np.zeros((displayer_h, displayer_w))  # all zeros = black
back_img = Image.new('RGBA', (displayer_w, displayer_h))
# Start the canvas out black over its whole area.
back_img.paste(Image.fromarray(back_array.astype(np.uint8)), (0, 0))
def show():
    """UI thread: compose the full-screen canvas and handle key presses.

    Each frame:
    1. Pastes the trace ``timg`` at the top-left of ``back_img`` and the
       256x256 ``mnist_img`` at its centre.
    2. Binarizes the result.
    3. Overlays either the collection-mode banner or the latest prediction.
    4. Shows the canvas in the window named ".".

    Keys:
        q    -- set ``status`` to -1 (stop all threads)
        c    -- set ``status`` to 3 (data-collection mode)
        s    -- force ``status`` back to 0 and reset ``input_label``
        0-9  -- set ``input_label`` to that digit
        f    -- toggle full screen; when windowed, the raw camera frame and
                the trace are also shown in debug windows

    All OpenCV windows are destroyed when the loop ends.
    """
    global status, org_img, timg, mnist_img, num_labels, back_img, input_label
    # cv2.WINDOW_NORMAL exists in OpenCV 2.4 through 4.x.  The former
    # cv2.cv.CV_WINDOW_NORMAL breaks on OpenCV >= 3, which removed cv2.cv.
    cv2.namedWindow(".", cv2.WINDOW_NORMAL)
    is_fullscreen = 1
    # Window property 0 is the full-screen flag (1 = full screen, 0 = normal).
    cv2.setWindowProperty(".", 0, is_fullscreen)
    # Fixed canvas positions: the result image is centred; text goes below
    # and to the right of it.
    result_pos = (int(displayer_w/2 - 128), int(displayer_h/2 - 128))
    text_pos = (int(displayer_w/2 + 256), int(displayer_h/2 + 256))
    try:
        while status != -1:
            if is_fullscreen == 0:
                # Debug views, only visible outside full screen.
                cv2.imshow("show", org_img)
                cv2.imshow("timg", timg)
            back_img.paste(Image.fromarray(np.uint8(timg)), (0, 0))
            back_img.paste(Image.fromarray(np.uint8(mnist_img)), result_pos)
            back_img_cv2 = cv2.cvtColor(np.asarray(back_img), cv2.COLOR_RGB2BGR)
            retval, back_img_cv2 = cv2.threshold(back_img_cv2, 1, 255, cv2.THRESH_BINARY)
            back_img_cv2 = back_img_cv2.astype(np.uint8)
            if status == 3:
                cv2.putText(back_img_cv2, "[Training ... %d]"%(input_label), text_pos, cv2.FONT_HERSHEY_PLAIN, 4.5, (255,255,255), 5, 5)
            elif len(num_labels) != 0:
                cv2.putText(back_img_cv2, str(num_labels), text_pos, cv2.FONT_HERSHEY_PLAIN, 5.0, (255,255,255), 5, 5)
            cv2.imshow(".", back_img_cv2)
            key = cv2.waitKey(1) & 0xFF
            if key == ord('q'):  # quit the program
                status = -1
            elif key == ord('c'):  # start collecting a sample
                status = 3
            elif key == ord('s'):  # force back to the normal state
                status = 0
                input_label = -127
            elif ord('0') <= key <= ord('9'):  # digit label for collection
                input_label = key - ord('0')
            elif key == ord('f'):  # toggle full screen
                is_fullscreen = 0 if is_fullscreen == 1 else 1
                cv2.setWindowProperty(".", 0, is_fullscreen)
    finally:
        cv2.destroyAllWindows()
# Program entry: start the UI thread and the backend detection thread.
ret = thread.start_new_thread(show, ())
ret = thread.start_new_thread(detect_dpw, ())
# Wait until some thread sets status to -1.  Sleeping between polls keeps
# the main thread from busy-spinning a core and hogging the GIL that both
# worker threads need.  Once the main thread exits, threads started via the
# `thread` module are killed on most platforms.
while status != -1:
    time.sleep(0.05)