forked from vietnh1009/Yolo-v2-pytorch
-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_voc_images.py
88 lines (79 loc) · 3.92 KB
/
test_voc_images.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""
@author: Viet Nguyen <[email protected]>
"""
import os
import glob
import argparse
import pickle
import cv2
import numpy as np
from src.utils import *
from src.yolo_net import Yolo
# Object class names, in the index order used to look up the drawing colour
# and handed to post_processing (Pascal VOC, judging by the file/model names).
CLASSES = [
    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
    'bus', 'car', 'cat', 'chair', 'cow',
    'diningtable', 'dog', 'horse', 'motorbike', 'person',
    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',
]
def get_args(argv=None):
    """Parse the command-line options of the image detection script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what
            callers that pass no argument have always received.

    Returns:
        argparse.Namespace with the attributes ``image_size``,
        ``conf_threshold``, ``nms_threshold``, ``pre_trained_model_type``,
        ``pre_trained_model_path``, ``input`` and ``output``.
    """
    # The first positional parameter of ArgumentParser is `prog`, not the
    # description, so the title has to be passed by keyword; otherwise it
    # replaces the program name in the usage line.
    parser = argparse.ArgumentParser(
        description="You Only Look Once: Unified, Real-Time Object Detection")
    parser.add_argument("--image_size", type=int, default=448,
                        help="The common width and height for all images")
    parser.add_argument("--conf_threshold", type=float, default=0.35)
    parser.add_argument("--nms_threshold", type=float, default=0.5)
    parser.add_argument("--pre_trained_model_type", type=str,
                        choices=["model", "params"], default="model")
    parser.add_argument("--pre_trained_model_path", type=str,
                        default="trained_models/whole_model_trained_yolo_voc")
    parser.add_argument("--input", type=str, default="test_images")
    parser.add_argument("--output", type=str, default="test_images")
    return parser.parse_args(argv)
def _load_model(opt, use_cuda):
    """Load the network selected by ``opt`` and put it in evaluation mode.

    ``opt.pre_trained_model_type == "model"`` means the checkpoint is a whole
    pickled model; any other value means it is a state dict that is loaded
    into a freshly built ``Yolo``. When ``use_cuda`` is true the model is
    moved to the GPU so that it lives on the same device as the input.
    """
    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from a trusted source.
    if use_cuda:
        checkpoint = torch.load(opt.pre_trained_model_path)
    else:
        # Remap any GPU storages in the checkpoint onto the CPU.
        checkpoint = torch.load(opt.pre_trained_model_path,
                                map_location=lambda storage, loc: storage)

    if opt.pre_trained_model_type == "model":
        model = checkpoint
    else:
        model = Yolo(len(CLASSES))
        model.load_state_dict(checkpoint)

    if use_cuda:
        # A freshly constructed Yolo (and a model saved from the CPU) is on
        # the CPU; without this the forward pass on a CUDA input fails.
        model = model.cuda()
    model.eval()
    return model


def test(opt):
    """Run detection on every ``*.jpg`` in ``opt.input`` and save annotated copies.

    For each image the detections returned by ``post_processing`` are scaled
    back to the original resolution, drawn (box plus "class : score" label)
    on the original image, printed to stdout, and the result is written to
    ``<opt.output>/<name>_prediction.jpg``. Images without detections produce
    no output file. Files whose name contains "prediction" are skipped so
    that earlier results are not processed again.

    Args:
        opt: Namespace providing ``image_size``, ``conf_threshold``,
            ``nms_threshold``, ``pre_trained_model_type``,
            ``pre_trained_model_path``, ``input`` and ``output``.
    """
    use_cuda = torch.cuda.is_available()
    model = _load_model(opt, use_cuda)

    # Per-class drawing colours, indexed in the same order as CLASSES.
    with open("src/pallete", "rb") as palette_file:
        colors = pickle.load(palette_file)

    # cv2.imwrite fails silently when the target directory does not exist.
    if opt.output:
        os.makedirs(opt.output, exist_ok=True)

    for image_path in glob.iglob(os.path.join(opt.input, "*.jpg")):
        file_name = os.path.basename(image_path)
        # Test the file name only, so a directory called e.g. "predictions"
        # does not cause every image in it to be skipped.
        if "prediction" in file_name:
            continue

        # Keep the untouched BGR image for drawing; the network gets an RGB,
        # resized copy.
        output_image = cv2.imread(image_path)
        if output_image is None:
            print("Skipping unreadable image: {}".format(image_path))
            continue
        height, width = output_image.shape[:2]

        image = cv2.cvtColor(output_image, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, (opt.image_size, opt.image_size))
        # HWC -> CHW, then add a leading batch axis of size 1.
        image = np.transpose(np.array(image, dtype=np.float32), (2, 0, 1))
        image = image[None, :, :, :]

        # Factors that map coordinates in the resized image back to the
        # original image.
        width_ratio = float(opt.image_size) / width
        height_ratio = float(opt.image_size) / height

        # torch.no_grad() already implies a PyTorch version in which the
        # Variable wrapper is a no-op, so the tensor is used directly.
        data = torch.FloatTensor(image)
        if use_cuda:
            data = data.cuda()

        with torch.no_grad():
            logits = model(data)
            predictions = post_processing(logits, opt.image_size, CLASSES, model.anchors,
                                          opt.conf_threshold, opt.nms_threshold)
        if len(predictions) == 0:
            continue

        # Only one image was fed, so only the first batch entry is relevant.
        # Each pred is used as (x, y, w, h, score, class name).
        for pred in predictions[0]:
            xmin = int(max(pred[0] / width_ratio, 0))
            ymin = int(max(pred[1] / height_ratio, 0))
            xmax = int(min((pred[0] + pred[2]) / width_ratio, width))
            ymax = int(min((pred[1] + pred[3]) / height_ratio, height))
            color = colors[CLASSES.index(pred[5])]
            label = pred[5] + ' : %.2f' % pred[4]

            cv2.rectangle(output_image, (xmin, ymin), (xmax, ymax), color, 2)
            text_size = cv2.getTextSize(label, cv2.FONT_HERSHEY_PLAIN, 1, 1)[0]
            # Filled background for the label, anchored at the box's top-left.
            cv2.rectangle(output_image, (xmin, ymin),
                          (xmin + text_size[0] + 3, ymin + text_size[1] + 4), color, -1)
            cv2.putText(
                output_image, label,
                (xmin, ymin + text_size[1] + 4), cv2.FONT_HERSHEY_PLAIN, 1,
                (255, 255, 255), 1)
            # NOTE: the values are printed as (xmin,xmax) (ymin,ymax).
            print("Object: {}, Bounding box: ({},{}) ({},{})".format(pred[5], xmin, xmax, ymin, ymax))

        output_name = os.path.splitext(file_name)[0] + "_prediction.jpg"
        cv2.imwrite(os.path.join(opt.output, output_name), output_image)
if __name__ == "__main__":
    # Script entry point: parse the command-line options and run detection.
    test(get_args())