-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_30TestMain.py
44 lines (37 loc) · 1.42 KB
/
_30TestMain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
# -*- coding: utf-8 -*-
"""
@Author : Xiaoqi Cheng
@Time : 2021/1/15 9:52
"""
import torch
from _99Normalization import *
from _31TestMultiPipeDatasetLoader import *
from _03FCN import *
from _21CalEvaluationMetrics import *
# %% Load dataset and model
# Test split of the dataset under FolderPath; Width=2 and ShowSample=False are
# forwarded unchanged to the project's PipeDatasetLoader.
FolderPath = 'PotData'
TestDataset, TestDataLoader = PipeDatasetLoader(FolderPath, Width=2, ShowSample=False)
# Prefer the GPU when one is present, but fall back to the CPU so the script
# does not crash outright on a machine without CUDA.
Device = 'cuda' if torch.cuda.is_available() else 'cpu'
# 2 input channels -> 1 output channel, with a sigmoid applied to the last layer.
Model = Net(InputChannels=2, OutputChannels=1, InitFeatures=32,
            WithActivateLast=True, ActivateFunLast=torch.sigmoid).to(Device)
# map_location places the checkpoint's tensors on the chosen device, so a
# checkpoint saved from a GPU run can also be loaded when only a CPU is present.
Model.load_state_dict(torch.load('1PotData_Width2_BCE_Input2/0700.pt', map_location=Device))
# Folder that receives the per-sample overlay images written by the test loop.
SaveFolder = 'TestResult'
# %% Testing
Model.eval()
# Send inputs to wherever the model's weights actually live, rather than
# hard-coding a device separately from where the model was placed.
ModelDevice = next(Model.parameters()).device
# Pixels whose prediction (on the 0-255 scale) exceeds this value are painted
# into the overlay.
Threshold = 100
# Create the output folder once, not on every iteration.
os.makedirs(SaveFolder, exist_ok=True)
with torch.no_grad():  # inference only: no autograd graph is needed
    for Input, TMImg, SampleName in TestDataLoader:
        print(SampleName)
        # The loader yields the inputs as a sequence of tensors; they are
        # concatenated along dim 1 (presumably the channel axis, NCHW -- TODO confirm).
        Input = torch.cat(Input, dim=1)
        InputImg = Input.float().to(ModelDevice)
        OutputImg = Model(InputImg)
        # Generate the result image from element [0, 0] only
        # (NOTE(review): assumes batch size 1 -- other samples in a batch are ignored).
        OutputImg = OutputImg.cpu().numpy()[0, 0]
        # Clip before the uint8 cast so out-of-range values saturate instead of wrapping.
        OutputImg = np.clip(OutputImg * 255, 0, 255).astype(np.uint8)
        TMImg = TMImg.numpy()[0][0]
        TMImg = (Normalization(TMImg) * 255).astype(np.uint8)
        # Replicate the grayscale image into 3 channels so a colored overlay can be drawn.
        ResultImg = cv2.cvtColor(TMImg, cv2.COLOR_GRAY2RGB)
        # Set channel 2 to 255 wherever the prediction exceeds the threshold;
        # cv2.imwrite interprets channel 2 as red.
        ResultImg[OutputImg > Threshold, 2] = 255
        cv2.imwrite(os.path.join(SaveFolder, SampleName[0] + '.png'), ResultImg)