# main.py (112 lines, 81 loc, 3.31 KB) -- forked from Shiaoming/Python-VO
import os
import numpy as np
import cv2
import argparse
import yaml
import logging
from utils.tools import plot_keypoints
from DataLoader import create_dataloader
from Detectors import create_detector
from Matchers import create_matcher
from VO.VisualOdometry import VisualOdometry, AbosluteScaleComputer
from tqdm import tqdm
def keypoints_plot(img, vo):
    """Overlay the current frame's keypoints on ``img``.

    Args:
        img: Image array, either single-channel (``(H, W)`` or ``(H, W, 1)``)
            or multi-channel (``(H, W, C)``).
        vo: Object exposing ``kptdescs["cur"]`` with ``"keypoints"`` and
            ``"scores"`` entries for the current frame.

    Returns:
        Whatever ``plot_keypoints`` returns for the (possibly colour-converted)
        image -- presumably the annotated image; confirm in ``utils.tools``.
    """
    # A single-channel image may arrive without a channel axis at all; check
    # ``ndim`` first so ``shape[2]`` is never read on a 2-D array.
    if img.ndim == 2 or img.shape[2] == 1:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2RGB)

    current = vo.kptdescs["cur"]
    return plot_keypoints(img, current["keypoints"], current["scores"])
class TrajPlotter(object):
    """Accumulate and draw estimated vs. ground-truth positions on a canvas.

    Each call to :meth:`update` plots one estimated point and one ground-truth
    point using their x and z components, and refreshes an on-canvas readout
    of the running mean position error.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, canvas_size=600, origin=(290, 90)):
        """Create an empty black canvas.

        Args:
            canvas_size: Side length in pixels of the square canvas.
            origin: ``(x, y)`` pixel offset added to every plotted point.
        """
        # Per-update Euclidean distance between estimate and ground truth (x/z only).
        self.errors = []
        # Pixel offset applied to plotted coordinates.
        self.origin = origin
        # Canvas the trajectory is drawn onto; reused and returned by update().
        self.traj = np.zeros((canvas_size, canvas_size, 3), dtype=np.uint8)

    def update(self, est_xyz, gt_xyz):
        """Plot one estimated / ground-truth position pair and return the canvas.

        Args:
            est_xyz: Estimated position; any array-like whose flattened form has
                x at index 0 and z at index 2 (e.g. shape ``(3,)`` or ``(3, 1)``).
            gt_xyz: Ground-truth position in the same layout.

        Returns:
            The canvas array (``self.traj``), modified in place.
        """
        # Flatten first so column vectors such as (3, 1) yield true scalars;
        # calling int() on a 1-element array is deprecated in recent NumPy.
        est_flat = np.asarray(est_xyz, dtype=float).ravel()
        gt_flat = np.asarray(gt_xyz, dtype=float).ravel()
        x, z = est_flat[0], est_flat[2]
        gt_x, gt_z = gt_flat[0], gt_flat[2]

        error = np.linalg.norm(np.array([x, z]) - np.array([gt_x, gt_z]))
        self.errors.append(error)
        avg_error = np.mean(self.errors)

        # === drawer ==================================
        # each point (int() truncates towards zero)
        off_x, off_y = self.origin
        draw_x, draw_y = int(x) + off_x, int(z) + off_y
        true_x, true_y = int(gt_x) + off_x, int(gt_z) + off_y

        # draw trajectory: estimate in (0, 255, 0), ground truth in (0, 0, 255)
        cv2.circle(self.traj, (draw_x, draw_y), 1, (0, 255, 0), 1)
        cv2.circle(self.traj, (true_x, true_y), 1, (0, 0, 255), 2)

        # Blank the text strip before redrawing so old readouts do not pile up.
        cv2.rectangle(self.traj, (10, 20), (self.traj.shape[1], 80), (0, 0, 0), -1)

        # draw text
        text = "[AvgError] %2.4fm" % (avg_error)
        cv2.putText(self.traj, text, (20, 40),
                    cv2.FONT_HERSHEY_PLAIN, 1, (255, 255, 255), 1, 8)

        return self.traj
def run(args):
    """Run visual odometry over a dataset, logging and displaying every frame.

    Reads the YAML file at ``args.config``, builds the dataloader, detector and
    matcher from its ``"dataset"``, ``"detector"`` and ``"matcher"`` sections,
    then for each frame:

    * updates the odometry with the frame and the scale derived from the
      ground-truth pose,
    * appends one line ``i tx ty tz gt_x gt_y gt_z`` to ``results/<name>.txt``,
    * shows the keypoint overlay and the trajectory canvas in OpenCV windows.

    Pressing ESC (key code 27) stops the loop early. The last trajectory
    canvas is written to ``results/<name>.png``, where ``<name>`` is the config
    file's base name up to its first ``'.'``.

    Args:
        args: Namespace with a ``config`` attribute holding the config path.
    """
    with open(args.config, 'r') as f:
        # NOTE(review): yaml.full_load can construct more than plain data types.
        # Acceptable for trusted local configs; prefer yaml.safe_load if config
        # files may come from an untrusted source.
        config = yaml.full_load(f)

    # create dataloader
    loader = create_dataloader(config["dataset"])
    # create detector
    detector = create_detector(config["detector"])
    # create matcher
    matcher = create_matcher(config["matcher"])

    absscale = AbosluteScaleComputer()
    traj_plotter = TrajPlotter()

    # log: basename() keeps this working for paths using the platform separator.
    fname = os.path.basename(args.config).split('.')[0]
    os.makedirs("results", exist_ok=True)

    vo = VisualOdometry(detector, matcher, loader.cam)

    # Give the progress bar a total when the loader can report its length;
    # otherwise fall back to a plain iteration counter.
    try:
        total = len(loader)
    except TypeError:
        total = None

    traj_img = None
    log_path = os.path.join("results", fname + ".txt")
    # Context managers guarantee the log is flushed/closed and the bar is
    # finalised even if a frame raises.
    with open(log_path, mode='a') as log_fopen, tqdm(total=total) as pbar:
        for i, img in enumerate(loader):
            gt_pose = loader.get_cur_pose()
            R, t = vo.update(img, absscale.update(gt_pose))

            # === log writer ==============================
            print(i, t[0, 0], t[1, 0], t[2, 0],
                  gt_pose[0, 3], gt_pose[1, 3], gt_pose[2, 3], file=log_fopen)

            # === drawer ==================================
            kpt_img = keypoints_plot(img, vo)
            traj_img = traj_plotter.update(t, gt_pose[:, 3])

            cv2.imshow("keypoints", kpt_img)
            cv2.imshow("trajectory", traj_img)
            if cv2.waitKey(10) == 27:
                break

            pbar.update()

    # Nothing to save when the loader produced no frames.
    if traj_img is not None:
        cv2.imwrite(os.path.join("results", fname + '.png'), traj_img)
if __name__ == "__main__":
    # Command-line entry point: parse options, configure logging, run the pipeline.
    parser = argparse.ArgumentParser(description='python_vo')
    parser.add_argument('--config', type=str, default='params/kitti_superpoint_supergluematch.yaml',
                        help='config file')
    parser.add_argument('--logging', type=str, default='INFO',
                        help='logging level: NOTSET, DEBUG, INFO, WARNING, ERROR, CRITICAL')
    args = parser.parse_args()

    # basicConfig accepts level names directly, so the private
    # logging._nameToLevel table is not needed; an unknown name raises
    # ValueError, which is reported as a normal usage error.
    try:
        logging.basicConfig(level=args.logging.upper())
    except ValueError:
        parser.error("unknown logging level: %r" % args.logging)

    run(args)