-
Notifications
You must be signed in to change notification settings - Fork 0
/
feature_matcher.py
74 lines (52 loc) · 2.61 KB
/
feature_matcher.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os

import cv2
import numpy as np
import matplotlib.pyplot as plt
class FeatureMatcher:
    """Locate a query image inside a training image using SURF + FLANN.

    Keypoints are detected with SURF (``cv2.xfeatures2d``, i.e. an OpenCV
    build with the contrib modules), matched with a FLANN k-d tree matcher,
    filtered with Lowe's ratio test and, when enough matches survive, a
    RANSAC homography is used to outline the query inside the training image.
    """

    def __init__(self, train_img, min_hessian=400, min_matches=10):
        """Load the training image and create the SURF detector.

        Args:
            train_img: Path to the training image.
            min_hessian: Hessian threshold forwarded to ``SURF_create``.
            min_matches: ``match`` needs strictly more good matches than this
                before it estimates a homography.
        """
        # Grayscale copy for feature detection, colour copy for drawing.
        self._train = cv2.imread(train_img, 0)
        self._orig_train = cv2.imread(train_img)
        self._min_matches = min_matches
        self._flann_trees = 0  # NOTE(review): never read in this file.
        self.SURF = cv2.xfeatures2d.SURF_create(min_hessian)
        self._res = None  # last visualisation produced by match()

    @staticmethod
    def _ratio_test(knn_matches, ratio):
        """Return the best match of every pair that passes Lowe's ratio test.

        ``knn_matches`` is the output of ``knnMatch(..., k=2)``: one sequence
        of up to two candidates per query descriptor. Entries with fewer than
        two candidates cannot be ratio-tested and are skipped.
        """
        return [
            pair[0]
            for pair in knn_matches
            if len(pair) >= 2 and pair[0].distance < ratio * pair[1].distance
        ]

    def match(self, query_img, kd_trees=5, ratio=0.7):
        """Match ``query_img`` against the training image.

        Args:
            query_img: Path to the query image.
            kd_trees: Number of randomized k-d trees in the FLANN index.
            ratio: Lowe ratio-test threshold; a match is kept when its
                distance is below ``ratio`` times the second-best distance.

        Returns:
            The match visualisation (also stored for ``save_match`` and
            ``show_match``), or ``None`` when too few matches survive or no
            homography can be estimated.
        """
        # Drop any earlier result so a failed match cannot leave a stale
        # image behind for save_match()/show_match().
        self._res = None
        query = cv2.imread(query_img, 0)
        orig_query = cv2.imread(query_img)

        kp_t, desc_t = self.SURF.detectAndCompute(self._train, None)
        kp_q, desc_q = self.SURF.detectAndCompute(query, None)
        if desc_t is None or desc_q is None:
            # No keypoints in one of the images: nothing to match.
            good_matches = []
        else:
            # FLANN's id for the k-d tree index is 1; 0 is the linear
            # (brute-force) index, which would silently ignore ``trees``.
            FLANN_INDEX_KDTREE = 1
            index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=kd_trees)
            search_params = dict(checks=50)
            flann = cv2.FlannBasedMatcher(index_params, search_params)
            matches = flann.knnMatch(desc_q, desc_t, k=2)
            good_matches = self._ratio_test(matches, ratio)

        if len(good_matches) <= self._min_matches:
            print("Couldn't find enough matches: {}/{}".format(
                len(good_matches), self._min_matches))
            return None

        src_pts = np.float32(
            [kp_q[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        dst_pts = np.float32(
            [kp_t[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)
        # M maps query coordinates onto training-image coordinates.
        M, mask = cv2.findHomography(src_pts, dst_pts, cv2.RANSAC, 5.0)
        if M is None:
            print("Couldn't estimate a homography from {} matches".format(
                len(good_matches)))
            return None
        matches_mask = mask.ravel().tolist()

        # Project the query's corners into the training image and outline
        # them there; draw on a copy so the stored image stays unmodified.
        h, w = query.shape
        corners = np.float32(
            [[0, 0], [0, h - 1], [w - 1, h - 1], [w - 1, 0]]).reshape(-1, 1, 2)
        outline = cv2.perspectiveTransform(corners, M)
        train_vis = cv2.polylines(self._orig_train.copy(), [np.int32(outline)],
                                  True, (255, 0, 0), 3, cv2.LINE_AA)

        draw_params = dict(matchColor=(0, 255, 0), singlePointColor=None,
                           matchesMask=matches_mask,
                           flags=2)  # 2 == DrawMatchesFlags NOT_DRAW_SINGLE_POINTS
        self._res = cv2.drawMatches(orig_query, kp_q, train_vis, kp_t,
                                    good_matches, None, **draw_params)
        return self._res

    def save_match(self, path="result.jpg"):
        """Write the last match visualisation to ``path``.

        Args:
            path: Output file. OpenCV chooses the encoder from the extension,
                so ".jpg" is appended when ``path`` has none.

        Returns:
            True if the image was written; False if there is no result yet or
            OpenCV could not write the file.
        """
        if self._res is None:
            return False
        if not os.path.splitext(path)[1]:
            path += ".jpg"
        return bool(cv2.imwrite(path, self._res))

    def show_match(self):
        """Display the last match visualisation with matplotlib."""
        if self._res is None:
            print("No match result to show")
            return
        # OpenCV images are BGR; matplotlib expects RGB.
        plt.imshow(cv2.cvtColor(self._res, cv2.COLOR_BGR2RGB))
        plt.show()
if __name__ == "__main__":
    # Interactive entry point: ask for both image paths, match them, and
    # save/show the visualisation only if matching produced one.
    train_img = input("Enter your training image (please provide full path): ").strip()
    query_img = input("Enter your query image (please provide full path): ").strip()
    job = FeatureMatcher(train_img)
    job.match(query_img)
    # match() leaves the result unset when too few matches were found;
    # saving or showing it then would fail.
    if job._res is not None:
        job.save_match()
        job.show_match()