forked from bubbliiiing/unet-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
predict.py
106 lines (95 loc) · 4.78 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
#----------------------------------------------------#
# 对视频中的predict.py进行了修改,
# 将单张图片预测、摄像头检测和FPS测试功能
# 整合到了一个py文件中,通过指定mode进行模式的修改。
#----------------------------------------------------#
import time
import cv2
import numpy as np
from PIL import Image
from unet import Unet
if __name__ == "__main__":
#-------------------------------------------------------------------------#
# 如果想要修改对应种类的颜色,到generate函数里修改self.colors即可
#-------------------------------------------------------------------------#
unet = Unet()
#-------------------------------------------------------------------------#
# mode用于指定测试的模式:
# 'predict'表示单张图片预测
# 'video'表示视频检测
# 'fps'表示测试fps
#-------------------------------------------------------------------------#
mode = "predict"
#-------------------------------------------------------------------------#
# video_path用于指定视频的路径,当video_path=0时表示检测摄像头
# video_save_path表示视频保存的路径,当video_save_path=""时表示不保存
# video_fps用于保存的视频的fps
# video_path、video_save_path和video_fps仅在mode='video'时有效
# 保存视频时需要ctrl+c退出才会完成完整的保存步骤,不可直接结束程序。
#-------------------------------------------------------------------------#
video_path = 0
video_save_path = ""
video_fps = 25.0
if mode == "predict":
'''
predict.py有几个注意点
1、该代码无法直接进行批量预测,如果想要批量预测,可以利用os.listdir()遍历文件夹,利用Image.open打开图片文件进行预测。
具体流程可以参考get_miou_prediction.py,在get_miou_prediction.py即实现了遍历。
2、如果想要保存,利用r_image.save("img.jpg")即可保存。
3、如果想要原图和分割图不混合,可以把blend参数设置成False。
4、如果想根据mask获取对应的区域,可以参考detect_image函数中,利用预测结果绘图的部分,判断每一个像素点的种类,然后根据种类获取对应的部分。
seg_img = np.zeros((np.shape(pr)[0],np.shape(pr)[1],3))
for c in range(self.num_classes):
seg_img[:, :, 0] += ((pr == c)*( self.colors[c][0] )).astype('uint8')
seg_img[:, :, 1] += ((pr == c)*( self.colors[c][1] )).astype('uint8')
seg_img[:, :, 2] += ((pr == c)*( self.colors[c][2] )).astype('uint8')
'''
while True:
img = input('Input image filename:')
try:
image = Image.open(img)
except:
print('Open Error! Try again!')
continue
else:
r_image = unet.detect_image(image)
r_image.show()
elif mode == "video":
capture=cv2.VideoCapture(video_path)
if video_save_path!="":
fourcc = cv2.VideoWriter_fourcc(*'XVID')
size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
fps = 0.0
while(True):
t1 = time.time()
# 读取某一帧
ref,frame=capture.read()
# 格式转变,BGRtoRGB
frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
# 转变成Image
frame = Image.fromarray(np.uint8(frame))
# 进行检测
frame = np.array(unet.detect_image(frame))
# RGBtoBGR满足opencv显示格式
frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
fps = ( fps + (1./(time.time()-t1)) ) / 2
print("fps= %.2f"%(fps))
frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
cv2.imshow("video",frame)
c= cv2.waitKey(1) & 0xff
if video_save_path!="":
out.write(frame)
if c==27:
capture.release()
break
capture.release()
out.release()
cv2.destroyAllWindows()
elif mode == "fps":
test_interval = 100
img = Image.open('img/street.jpg')
tact_time = unet.get_FPS(img, test_interval)
print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
else:
raise AssertionError("Please specify the correct mode: 'predict', 'video' or 'fps'.")