-
Notifications
You must be signed in to change notification settings - Fork 119
/
mnist_cnn_gui_main.py
192 lines (138 loc) · 5.61 KB
/
mnist_cnn_gui_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import sys, os
import numpy as np
from dataset.mnist import load_mnist
from PIL import Image, ImageQt
from qt.layout import Ui_MainWindow
from qt.paintboard import PaintBoard
from PyQt5.QtWidgets import QMainWindow, QDesktopWidget, QApplication
from PyQt5.QtWidgets import QLabel, QMessageBox, QPushButton, QFrame
from PyQt5.QtGui import QPainter, QPen, QPixmap, QColor, QImage
from PyQt5.QtCore import Qt, QPoint, QSize
from simple_convnet import SimpleConvNet
from common.functions import softmax
from deep_convnet import DeepConvNet
# Input modes selectable in the GUI.
MODE_MNIST = 1  # random sample drawn from the MNIST test set
MODE_WRITE = 2  # digit handwritten with the mouse

# Confidence threshold for the recognition result.
Thresh = 0.5

# Load the MNIST dataset; only the test images are kept.
(_, _), (x_test, _) = load_mnist(
    normalize=True,
    flatten=False,
    one_hot_label=False,
)

# Initialise the network.
# Network 1: simple CNN
#   conv - relu - pool - affine - relu - affine - softmax
network = SimpleConvNet(
    input_dim=(1, 28, 28),
    conv_param=dict(filter_num=30, filter_size=5, pad=0, stride=1),
    hidden_size=100,
    output_size=10,
    weight_init_std=0.01,
)
network.load_params("params.pkl")

# Network 2: deep CNN
# network = DeepConvNet()
# network.load_params("deep_convnet_params.pkl")
class MainWindow(QMainWindow, Ui_MainWindow):
    """Main window that runs the module-level ``network`` on one image.

    The image comes either from a randomly drawn MNIST test sample
    (``MODE_MNIST``) or from the paint board (``MODE_WRITE``).  The predicted
    digit and its softmax probability are kept in ``self.result`` as
    ``[digit, confidence]`` and shown in the ``lbResult`` / ``lbCofidence``
    labels created by ``setupUi``.
    """

    # Side length (pixels) of the paint board and of the enlarged MNIST preview.
    _DISPLAY_SIZE = 224
    # Side length (pixels) of the image handed to the network.
    _INPUT_SIZE = 28

    def __init__(self):
        """Build the UI, centre the window and attach an empty paint board."""
        super().__init__()

        # Initial state: MNIST mode, no prediction yet.
        self.mode = MODE_MNIST
        self.result = [0, 0]

        # Build the widgets declared in Ui_MainWindow.
        self.setupUi(self)
        self.center()

        # Paint board starts with a fully transparent fill and pen
        # (alpha 0); the write mode switches both to opaque colours.
        self.paintBoard = PaintBoard(
            self,
            Size=QSize(self._DISPLAY_SIZE, self._DISPLAY_SIZE),
            Fill=QColor(0, 0, 0, 0),
        )
        self.paintBoard.setPenColor(QColor(0, 0, 0, 0))
        self.dArea_Layout.addWidget(self.paintBoard)

        self.clearDataArea()

    def center(self) -> None:
        """Move the window so its frame is centred on the available screen area."""
        frame = self.frameGeometry()
        screen_center = QDesktopWidget().availableGeometry().center()
        frame.moveCenter(screen_center)
        self.move(frame.topLeft())

    def closeEvent(self, event) -> None:
        """Ask for confirmation before closing; ignore the event on "No"."""
        reply = QMessageBox.question(
            self,
            'Message',
            "Are you sure to quit?",
            QMessageBox.Yes | QMessageBox.No,
            QMessageBox.Yes,
        )
        if reply == QMessageBox.Yes:
            event.accept()
        else:
            event.ignore()

    def clearDataArea(self) -> None:
        """Clear the paint board, the image/result labels and ``self.result``."""
        self.paintBoard.Clear()
        self.lbDataArea.clear()
        self.lbResult.clear()
        self.lbCofidence.clear()
        self.result = [0, 0]

    # ------------------------------------------------------------------
    # Callbacks
    # ------------------------------------------------------------------

    def cbBox_Mode_Callback(self, text: str) -> None:
        """Switch the input mode according to the combo-box entry ``text``.

        Any text other than the two known entries is ignored.
        """
        if text == '1:MINIST随机抽取':
            self.mode = MODE_MNIST
            self.clearDataArea()
            self.pbtGetMnist.setEnabled(True)
            # Transparent fill and pen (alpha 0) while in MNIST mode.
            self.paintBoard.setBoardFill(QColor(0, 0, 0, 0))
            self.paintBoard.setPenColor(QColor(0, 0, 0, 0))
        elif text == '2:鼠标手写输入':
            self.mode = MODE_WRITE
            self.clearDataArea()
            self.pbtGetMnist.setEnabled(False)
            # Opaque black board with a white pen for handwriting.
            self.paintBoard.setBoardFill(QColor(0, 0, 0, 255))
            self.paintBoard.setPenColor(QColor(255, 255, 255, 255))

    def pbtClear_Callback(self) -> None:
        """Reset the input area and the displayed result."""
        self.clearDataArea()

    def pbtPredict_Callback(self) -> None:
        """Run the network on the current input and display digit and confidence.

        The input is converted QImage -> PIL image -> ``numpy`` array of shape
        ``(1, 1, 28, 28)`` scaled to ``[0, 1]``.  In MNIST mode, when no
        sample is currently displayed, an all-black image is used instead.
        """
        size = self._DISPLAY_SIZE
        side = self._INPUT_SIZE

        # Obtain the input as a QImage.
        if self.mode == MODE_WRITE:
            qimage = self.paintBoard.getContentAsQImage()
        else:
            pixmap = self.lbDataArea.pixmap()
            if pixmap is None or pixmap.isNull():
                # Nothing displayed: fall back to a pure black image.
                blank = np.zeros((size, size), dtype=np.uint8)
                qimage = ImageQt.ImageQt(Image.fromarray(blank))
            else:
                qimage = pixmap.toImage()

        # Downscale with PIL.  Image.LANCZOS is the filter formerly also
        # exposed as Image.ANTIALIAS, an alias removed in Pillow 10.
        pil_img = ImageQt.fromqimage(qimage)
        pil_img = pil_img.resize((side, side), Image.LANCZOS)

        # Grayscale, add batch/channel axes and scale to [0, 1].
        pixels = np.array(pil_img.convert('L'))
        img_array = pixels.reshape(1, 1, side, side) / 255.0

        # Turn the raw network output into probabilities.
        probs = softmax(network.predict(img_array))

        digit = np.argmax(probs)
        self.result[0] = digit            # predicted digit
        self.result[1] = probs[0, digit]  # its softmax probability

        self.lbResult.setText("%d" % (self.result[0]))
        self.lbCofidence.setText("%.8f" % (self.result[1]))

    def pbtGetMnist_Callback(self) -> None:
        """Show a randomly chosen MNIST test image, enlarged, in ``lbDataArea``."""
        self.clearDataArea()

        # randint's upper bound is exclusive, so len(x_test) makes every test
        # image, including the last one, eligible.
        index = np.random.randint(0, len(x_test))
        img = x_test[index].reshape(self._INPUT_SIZE, self._INPUT_SIZE)
        img = img * 0xff  # undo the [0, 1] normalisation

        pil_img = Image.fromarray(np.uint8(img))
        pil_img = pil_img.resize((self._DISPLAY_SIZE, self._DISPLAY_SIZE))

        # PIL image -> QImage -> QPixmap shown in the label.
        qimage = ImageQt.ImageQt(pil_img)
        self.lbDataArea.setPixmap(QPixmap.fromImage(qimage))
if __name__ == "__main__":
    # Create the Qt application, show the main window, then hand control to
    # the event loop and exit with its return code.
    app = QApplication(sys.argv)
    window = MainWindow()
    window.show()
    exit_code = app.exec_()
    sys.exit(exit_code)