-
Notifications
You must be signed in to change notification settings - Fork 20
/
test.py
113 lines (90 loc) · 4.37 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import os
import sys
import time
import datetime
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from torch.utils.data import DataLoader
import torchvision
# Change the current working directory (uncomment when running from Colab/Drive)
# os.chdir('/content/drive/My Drive/SignLanguageRecognition')
# Project-local submodules
from nnet.blstm import blstm
from utils.logger import *
from utils.parse_config import *
from utils.utils import *
if __name__ == '__main__':
    # Evaluation script: load the held-out test split, restore the saved
    # (B)LSTM model, and report classification accuracy.
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_config", type=str, default="config/Net.cfg", help="path to model definition file")
    parser.add_argument("--model_name", type=str, default="blstm", help="used model name (lstm, blstm)")
    parser.add_argument("--data_config", type=str, default="config/SLR_dataset.cfg", help="path to data config file")
    parser.add_argument("--crop_size", type=int, default=256, help="size of each crop image")
    opt = parser.parse_args()
    print(opt)

    # Read the data/model configuration files
    data_config = parse_data_config(opt.data_config)
    model_config = parse_model_config(opt.model_config)[opt.model_name]

    # Logging
    logger = Logger(data_config["log_path"])

    # Use the first GPU when available, otherwise fall back to CPU
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Fix the random seed for reproducibility
    torch.manual_seed(int(model_config["SEED"]))

    # Load the dataset arrays and convert them to tensors
    dataset_dir = data_config["dataset_dir"]
    np_data_x = np.load(os.path.join(dataset_dir, data_config['data_file_name']), allow_pickle=True)
    np_data_y = np.load(os.path.join(dataset_dir, data_config['label_file_name']), allow_pickle=True)
    data_x = torch.from_numpy(np_data_x)
    data_y = torch.from_numpy(np_data_y)

    # Hold out the trailing `test_data_size` fraction of the data as the test split
    data_len = len(data_x)
    test_data_num = int(data_len * float(data_config['test_data_size']))
    test_x = data_x[data_len - test_data_num:]
    test_y = data_y[data_len - test_data_num:]
    logger.logger.info("Test size: " + str(test_data_num))

    # Model hyper-parameters from the config
    batch_size = int(model_config["BATCH_SIZE"])
    cpu_nums = int(model_config["CPU_NUMS"])
    time_step = int(model_config["TIME_STEP"])
    input_size = int(model_config["INPUT_SIZE"])
    output_size = int(model_config["OUTPUT_SIZE"])

    # Saved model file name, e.g. "blstm_output<O>_input<T>x<I>.pkl"
    model_save_name = opt.model_name + "_output" + str(output_size) + "_input" + str(time_step) + "x" + str(input_size) + ".pkl"
    model_save_dir = data_config["model_save_dir"]
    model_save_path = os.path.join(model_save_dir, model_save_name)
    if not os.path.exists(model_save_path):
        logger.logger.error("model file does not exist!")
        # FIX: sys.exit with a non-zero status instead of the interactive-only
        # exit() helper, so failures are visible to callers/CI.
        sys.exit(1)

    # Build (sample, label) pairs: outermost list, then tuple, inner ndarrays
    data_test = list(test_x.numpy().reshape(1, -1, time_step, input_size))
    data_test.append(list(test_y.numpy().reshape(-1, 1)))
    data_test = list(zip(*data_test))

    # DataLoader — evaluation only, so no shuffling
    test_loader = DataLoader(data_test, batch_size=batch_size, num_workers=cpu_nums, pin_memory=True, shuffle=False)

    # FIX: reuse model_save_path (computed above) instead of re-joining the path,
    # and pass map_location=device so a checkpoint saved on GPU also loads on a
    # CPU-only machine.
    # NOTE(review): torch.load unpickles arbitrary objects — only load trusted
    # checkpoint files.
    best_model = torch.load(model_save_path, map_location=device).get('model').to(device)
    # Switch to evaluation mode (disables dropout / uses running BN stats)
    best_model.eval()

    final_predict = []
    ground_truth = []
    for step, (b_x, b_y) in enumerate(test_loader):
        b_x = b_x.type(torch.FloatTensor).to(device)
        b_y = b_y.type(torch.long).to(device)
        with torch.no_grad():
            prediction = best_model(b_x)  # rnn output for every time step
        ground_truth = ground_truth + b_y.view(b_y.size()[0]).cpu().numpy().tolist()
        # Classify from the last time step's output; torch.max returns
        # (max probability, argmax class index)
        pre_result = torch.max(F.softmax(prediction[:, -1, :], dim=1), 1)
        pre_class = pre_result[1].cpu().data.numpy().tolist()
        final_predict = final_predict + pre_class

    # Overall accuracy = fraction of predictions matching the ground truth
    ground_truth = np.asarray(ground_truth)
    final_predict = np.asarray(final_predict)
    accuracy = float((ground_truth == final_predict).astype(int).sum()) / float(final_predict.size)
    logger.logger.info("Test accuracy: " + str(accuracy))