forked from bubbliiiing/unet-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
unet.py
167 lines (143 loc) · 7.43 KB
/
unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import colorsys
import copy
import time
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from torch import nn
from nets.unet import Unet as unet
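#--------------------------------------------#
#   To predict with a model you trained yourself, two parameters must be
#   changed: both model_path and num_classes.
#   If a shape-mismatch error is raised, make sure model_path and
#   num_classes match the values that were used during training.
#--------------------------------------------#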
class Unet(object):
    """Inference wrapper around a trained U-Net segmentation network.

    The wrapper loads the weights found at ``model_path``, letterboxes input
    images to the network input size, runs the network and converts the
    per-pixel class indices into a colour image.

    Every entry of ``_defaults`` becomes an instance attribute and can be
    overridden with a keyword argument of the same name when constructing
    the object, e.g. ``Unet(model_path="my.pth", num_classes=3)``.
    """

    _defaults = {
        # Path of the weights file handed to torch.load.
        "model_path": 'model_data/unet_voc.pth',
        # Network input size as (height, width, channels).
        "model_image_size": (512, 512, 3),
        # Number of classes the network predicts.
        "num_classes": 21,
        # Move the network and its inputs to the GPU.
        "cuda": True,
        # --------------------------------#
        #   blend controls whether the coloured prediction is mixed
        #   with the original image in detect_image.
        # --------------------------------#
        "blend": True,
    }

    # ---------------------------------------------------#
    #   Initialise the U-Net wrapper
    # ---------------------------------------------------#
    def __init__(self, **kwargs):
        """Store the configuration and load the network.

        Args:
            **kwargs: Optional overrides for the entries of ``_defaults``
                (``model_path``, ``model_image_size``, ``num_classes``,
                ``cuda``, ``blend``).
        """
        self.__dict__.update(self._defaults)
        # Caller overrides win over the defaults (they used to be ignored).
        self.__dict__.update(kwargs)
        self.generate()

    # ---------------------------------------------------#
    #   Build the network and the class colours
    # ---------------------------------------------------#
    def generate(self):
        """Build the network, load its weights and prepare ``self.colors``."""
        self.net = unet(num_classes=self.num_classes,
                        in_channels=self.model_image_size[-1]).eval()

        # Deserialize onto the CPU so that a checkpoint saved on a GPU can
        # also be loaded on a CPU-only host; the model is moved below.
        state_dict = torch.load(self.model_path, map_location='cpu')
        self.net.load_state_dict(state_dict)

        if self.cuda:
            self.net = nn.DataParallel(self.net)
            self.net = self.net.cuda()

        print('{} model loaded.'.format(self.model_path))

        self.colors = self._build_colors()

    def _build_colors(self):
        """Return one (R, G, B) tuple per class index.

        Up to 21 classes a fixed palette is used; for more classes the hues
        are spread evenly over the HSV colour wheel.
        """
        if self.num_classes <= 21:
            return [(0, 0, 0), (128, 0, 0), (0, 128, 0), (128, 128, 0),
                    (0, 0, 128), (128, 0, 128), (0, 128, 128),
                    (128, 128, 128), (64, 0, 0), (192, 0, 0), (64, 128, 0),
                    (192, 128, 0), (64, 0, 128), (192, 0, 128),
                    (64, 128, 128), (192, 128, 128), (0, 64, 0),
                    (128, 64, 0), (0, 192, 0), (128, 192, 0), (0, 64, 128),
                    (128, 64, 128)]

        # Give every class its own hue at full saturation and value.
        hsv_tuples = [(index / self.num_classes, 1., 1.)
                      for index in range(self.num_classes)]
        rgb_tuples = (colorsys.hsv_to_rgb(*hsv) for hsv in hsv_tuples)
        return [(int(r * 255), int(g * 255), int(b * 255))
                for r, g, b in rgb_tuples]

    def letterbox_image(self, image, size):
        """Resize ``image`` into ``size`` without distortion, padding with grey.

        Args:
            image: PIL image to resize.
            size: Target ``(width, height)``.

        Returns:
            Tuple ``(new_image, nw, nh)``: the padded RGB image of exactly
            ``size`` and the width/height of the resized content centred
            inside it.
        """
        image = image.convert("RGB")
        iw, ih = image.size
        w, h = size
        scale = min(w / iw, h / ih)
        # Clamp to one pixel so extreme aspect ratios cannot yield an empty
        # resize target.
        nw = max(1, int(iw * scale))
        nh = max(1, int(ih * scale))

        image = image.resize((nw, nh), Image.BICUBIC)
        new_image = Image.new('RGB', size, (128, 128, 128))
        new_image.paste(image, ((w - nw) // 2, (h - nh) // 2))
        return new_image, nw, nh

    def _preprocess(self, image):
        """Letterbox ``image`` and turn it into a float batch of one.

        Returns:
            Tuple ``(images, nw, nh)`` where ``images`` is the network input
            tensor (on the GPU when ``self.cuda`` is set) and ``nw``/``nh``
            are the content width/height reported by ``letterbox_image``.
        """
        boxed, nw, nh = self.letterbox_image(
            image, (self.model_image_size[1], self.model_image_size[0]))
        # Scale to [0, 1], add a batch axis and go from NHWC to NCHW.
        batch = np.transpose([np.array(boxed) / 255], (0, 3, 1, 2))
        images = torch.from_numpy(batch).type(torch.FloatTensor)
        if self.cuda:
            images = images.cuda()
        return images, nw, nh

    def _crop_letterbox(self, pr, nw, nh):
        """Cut the grey padding off a network-sized 2-D prediction map."""
        top = (self.model_image_size[0] - nh) // 2
        left = (self.model_image_size[1] - nw) // 2
        return pr[top:top + nh, left:left + nw]

    def _predict_classes(self, images, nw, nh):
        """Run the network and return the class index of every content pixel."""
        with torch.no_grad():
            # First (only) sample of the batch; channels are moved last so
            # the softmax/argmax run over the class axis.
            pr = self.net(images)[0]
            pr = F.softmax(pr.permute(1, 2, 0), dim=-1).cpu().numpy().argmax(axis=-1)
        return self._crop_letterbox(pr, nw, nh)

    def _colorize(self, pr):
        """Map a 2-D array of class indices to an (H, W, 3) uint8 colour image."""
        palette = np.array(self.colors, dtype=np.uint8)
        return palette[pr]

    # ---------------------------------------------------#
    #   Segment one image
    # ---------------------------------------------------#
    def detect_image(self, image):
        """Segment ``image`` and return the coloured result.

        Args:
            image: PIL image in any mode; it is converted to RGB first so
                greyscale inputs do not fail during prediction.

        Returns:
            A PIL RGB image with the size of the input: the colour-coded
            class map, blended with the input when ``self.blend`` is set.
        """
        image = image.convert('RGB')
        # Keep an untouched copy of the input for the final blend.
        old_img = copy.deepcopy(image)
        orininal_w, orininal_h = image.size

        images, nw, nh = self._preprocess(image)
        pr = self._predict_classes(images, nw, nh)

        # Colour every pixel by its class and scale back to the input size.
        seg_img = self._colorize(pr)
        image = Image.fromarray(seg_img).resize((orininal_w, orininal_h))

        if self.blend:
            image = Image.blend(old_img, image, 0.7)
        return image

    def get_FPS(self, image, test_interval):
        """Measure the average time of one prediction on ``image``.

        Args:
            image: PIL image used as the benchmark input.
            test_interval: Number of timed prediction runs; must be >= 1.

        Returns:
            Average duration in seconds of one forward pass including the
            softmax/argmax post-processing and the padding crop.

        Raises:
            ValueError: If ``test_interval`` is smaller than 1.
        """
        if test_interval < 1:
            raise ValueError('test_interval must be at least 1')

        images, nw, nh = self._preprocess(image)
        # One untimed run first, so that one-off start-up cost is excluded.
        self._predict_classes(images, nw, nh)

        t1 = time.perf_counter()
        for _ in range(test_interval):
            self._predict_classes(images, nw, nh)
        t2 = time.perf_counter()
        tact_time = (t2 - t1) / test_interval
        return tact_time