-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path_20ValiMain.py
74 lines (62 loc) · 2.59 KB
/
_20ValiMain.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
#-*- coding:utf-8 _*-
"""
Validation: evaluate saved checkpoints and report MF, mAP and DIA-ODS
@Author : Xiaoqi Cheng
@Time : 2021/1/11 9:52
"""
import torch,glob
from _99Normalization import *
from _99SaveLoad import *
from _02MultiPipeDatasetLoader import *
from _03FCN import *
from _21CalEvaluationMetrics import *
# Configuration
# Folder that holds the checkpoints (<ModelName>.pt) and receives the result files.
SaveFolder = "3PotData_Input2(4096)_IF8_Epoch2800_Dilation2_kernel5_BigKernel4_SmlKernel4"
Width = 2  # forwarded to PipeDatasetLoader; its meaning is defined there (TODO confirm)
# %% Load METCD and model
FolderPath = '../PotData_1'
# Batch size 1 and single-process loading for both splits; only the
# validation loader is consumed by the evaluation loop.
LoaderKwargs = dict(
    TrainBatchSize=1,
    ValBatchSize=1,
    TrainNumWorkers=0,
    ValNumWorkers=0,
    Width=Width,
)
TrainDataset, TrainDataLoader, ValDataset, ValDataLoader = PipeDatasetLoader(FolderPath, **LoaderKwargs)
ModelNames = ['2800']  # checkpoint names (without the .pt suffix) to evaluate
# Prefer the GPU (the script previously hard-coded 'cuda'), but fall back to
# the CPU so evaluation can still run on a machine without CUDA.
Device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
for ModelName in ModelNames:
    # One report file per checkpoint; an existing report means this checkpoint
    # was already evaluated, so skip it instead of overwriting the result.
    SaveFilePath = os.path.join(SaveFolder, 'result' + ModelName + '.txt')
    if os.path.exists(SaveFilePath):
        print(SaveFilePath+' already exist!')
        continue
    Model = Net(InputChannels=2, OutputChannels=1, InitFeatures=8, WithActivateLast=True,
                ActivateFunLast=torch.sigmoid).to(Device)
    Model.load_state_dict(torch.load(os.path.join(SaveFolder, ModelName + '.pt'), map_location=Device))
    # %% Evaluation
    Model.eval()
    OutputS = []  # per-sample network outputs as NumPy arrays
    LabelS = []   # matching per-sample labels as NumPy arrays
    # no_grad() limits the autograd switch-off to inference; the previous
    # torch.set_grad_enabled(False) disabled gradients for the rest of the
    # (possibly interactive) session.
    with torch.no_grad():
        for Input, Label, _TMImg, _SampleName in ValDataLoader:
            # Input is a sequence of tensors joined along dim 1
            # (presumably the channel axis; TODO confirm in PipeDatasetLoader).
            InputImg = torch.cat(Input, dim=1).float().to(Device)
            OutputImg = Model(InputImg)
            # [0] takes the first (only, since ValBatchSize=1) sample of the batch.
            OutputS.append(OutputImg.cpu().numpy()[0])
            LabelS.append(Label.cpu().numpy()[0])
    # %% Calculate evaluation metrics
    # Stack the per-sample arrays row-wise (np.vstack) and flatten to 1-D so
    # every predicted value lines up with its label.
    OutputFlatten = np.vstack(OutputS).ravel()
    LabelFlatten = np.vstack(LabelS).ravel()
    _, _, MF, mAP = PRC_mAP_MF(LabelFlatten, OutputFlatten, ShowPRC=True)
    print('MF:', MF)
    print('mAP:', mAP)
    # DIA-ODS. The score gets its own name: assigning the result to DIA_ODS
    # would replace the imported metric function, breaking the call for any
    # further entry in ModelNames.
    DiaOds, OptThreshold = DIA_ODS(OutputS, LabelS, ShowCurve=True)
    print('DIA-ODS:', DiaOds, ' at threshold:', OptThreshold)
    with open(SaveFilePath, 'w', encoding='utf-8') as f:
        f.write('MF: ' + str(MF) + '\n')
        f.write('mAP: ' + str(mAP) + '\n')
        f.write('DIA-ODS: ' + str(DiaOds) + ' at threshold: ' + str(OptThreshold) + '\n')