This repository has been archived by the owner on Jun 5, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 134
/
single_video_reid.py
519 lines (429 loc) · 19.1 KB
/
single_video_reid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
# coding=utf-8
# given MOT results of each video, use ReID features to re-associate some tracks
# ignore switcher inside each track
# associate similar tracks within a range of frames, within some IOU range
# by using pretrained Person-ReID/vehicle-reID features
import argparse
import cv2
import os
import operator
import numpy as np
from tqdm import tqdm
from PIL import Image
import matplotlib
# avoid the warning "gdk_cursor_new_for_display:
# assertion 'GDK_IS_DISPLAY (display)' failed" with Python 3
matplotlib.use('Agg')
import torch
from utils import parse_meva_clip_name, expand_tlwh, tlwh_intersection
from torchreid.feature_extractor import FeatureExtractor
from torchreid.distance import compute_distance_matrix
from enqueuer_thread import VideoEnqueuer
from diva_io.video import VideoReader
from moviepy.editor import VideoFileClip
from utils import sec2time, valid_box
parser = argparse.ArgumentParser()
parser.add_argument("filepath", help="mot track result path")
parser.add_argument("videonamelst")
parser.add_argument("video_path")
parser.add_argument("newfilepath")
parser.add_argument("--gpuid", default=0, type=int,
help="gpu id")
parser.add_argument("--vehicle_reid_model", default=None)
parser.add_argument("--person_reid_model", default=None)
parser.add_argument("--use_lijun_video_loader", action="store_true")
parser.add_argument("--use_moviepy", action="store_true")
parser.add_argument("--reject_track_conf_thres", default=-1,
help="whether to reject track with avg conf score lower")
parser.add_argument("--tol_num_frame", default=200, type=int,
help="search for tracklet in the next 200/fps seconds.")
parser.add_argument("--expand_width_p", default=2.0, type=float)
parser.add_argument("--expand_height_p", default=0.5, type=float)
parser.add_argument("--feature_box_num", default=500, type=int,
help="maximum box num to use for feature extraction, -1 "
"means all")
parser.add_argument("--feature_box_gap", default=20, type=int,
help="interval when getting boxes")
parser.add_argument("--reject_dist_thres", default=150, type=float,
help="eu dist rejecting dist threshold")
parser.add_argument("--compare_method", default="min_all",
help="min_all / avg")
parser.add_argument("--max_size", type=int, default=1920)
parser.add_argument("--short_edge_size", type=int, default=1080)
parser.add_argument("--debug", action="store_true")
def preprocess(track_file, tol_num_frame=30,
expand_width_p=0.1, expand_height_p=0.1,
reject_track_conf_thres=-1):
# 1. read the tracks
# assuming sorted by frameid
data = []
with open(track_file, "r") as f:
for line in f:
frame_idx, track_id, left, top, width, height, conf, _, _, _ = line.strip().split(",")
data.append([frame_idx, track_id, left, top, width, height, conf])
if not data:
return 0, [], {}
data = np.array(data, dtype="float32") # [N, 7]
track_ids = np.unique(data[:, 1]).tolist()
track_data = {} # [num_track, K, 7]
for track_id in track_ids:
track_data[track_id] = data[data[:, 1] == track_id, :]
ori_track_num = len(track_data)
if reject_track_conf_thres > 0.:
good_track_ids = []
for track_id in track_data:
avg_conf = np.mean(track_data[track_id][:, 6])
if avg_conf > reject_track_conf_thres:
good_track_ids.append(track_id)
track_data = {track_data[t] for t in good_track_ids}
# 2. find possible linkage tracklet pairs
# assuming only happens at the end of each tracklet, within a num of frames
# n by n search
possible_pairs = {} # track_id -> a list of track_id
for idx1, track_id1 in enumerate(track_ids):
#print(track_id, track_data[idx][0, 0])
# track_id, first frame in ascending order already
# check this track's ending
# against other track's start within tol_num_frame, then check intersection
# with the expanded box
end_frame_idx1 = track_data[track_id1][-1, 0]
tlwh1 = track_data[track_id1][-1, 2:6] # the last box
expanded_tlwh1 = expand_tlwh(
tlwh1, expand_width_p, expand_height_p)
for idx2 in range(idx1+1, len(track_ids)):
track_id2 = track_ids[idx2]
start_frame_idx2 = track_data[track_id2][0, 0]
if start_frame_idx2 <= end_frame_idx1:
continue
if start_frame_idx2 - end_frame_idx1 < tol_num_frame:
# only check the start frame's box for this track
tlwh2 = track_data[track_id2][0, 2:6] # the first box
box_inter = tlwh_intersection(expanded_tlwh1, tlwh2)
if box_inter > 0:
if track_id1 not in possible_pairs:
possible_pairs[track_id1] = []
possible_pairs[track_id1].append(
[track_id2, start_frame_idx2, end_frame_idx1])
else:
# some track's end might be matched to multiple track within
# tolerance time, we only keep the earliest ones,
prev_start_frame_idx = possible_pairs[track_id1][0][1]
if start_frame_idx2 == prev_start_frame_idx:
possible_pairs[track_id1].append(
[track_id2, start_frame_idx2, end_frame_idx1])
return ori_track_num, track_data, possible_pairs
def reid(frame_iter, p_extractor, v_extractor,
p_track_data, p_candidates,
v_track_data, v_candidates,
feature_box_num=5, feature_box_gap=20,
reject_dist_thres=200,
compare_method="min_all"):
# get new track data for two classes
# track_data: track_id -> [K, 7], candidates: track_id -> a list of track_ids
# 1. get the frame_id=>boxes needed to extract feature based on candidates
frame_data = {}
needed_track_boxes = {} # "query_track_id"/"gallery_track_id" => boxes
def get_track_boxes(track_data, candidates):
for query_track_id in candidates:
query_key = "query_%s" % query_track_id
needed_track_boxes[query_key] = []
if feature_box_num <= 0:
# get all box (under frame_gap)
this_box_num_limit = len(track_data[query_track_id])
else:
this_box_num_limit = feature_box_num
# for query, get last few boxes
idx = 1
while idx <= len(track_data[query_track_id]) and \
len(needed_track_boxes[query_key]) < this_box_num_limit:
# get from the end of the tracklet
box = track_data[query_track_id][-idx] # [7]
needed_track_boxes[query_key].append(box)
idx += feature_box_gap
for gallery_track_id, _, _ in candidates[query_track_id]:
gallery_key = "gallery_%s" % gallery_track_id
if gallery_key not in needed_track_boxes:
needed_track_boxes[gallery_key] = []
if feature_box_num <= 0:
# get all box (under frame_gap)
this_box_num_limit = len(track_data[gallery_track_id])
else:
this_box_num_limit = feature_box_num
# for gallery, get first few boxes
idx = 0
while idx < len(track_data[gallery_track_id]) and \
len(needed_track_boxes[gallery_key]) < this_box_num_limit:
# get from the end of the tracklet
box = track_data[gallery_track_id][idx] # [7]
needed_track_boxes[gallery_key].append(box)
idx += feature_box_gap
get_track_boxes(p_track_data, p_candidates)
get_track_boxes(v_track_data, v_candidates)
for key in needed_track_boxes:
for box_idx, box in enumerate(needed_track_boxes[key]):
# [7],
frame_idx = box[0]
tlwh = box[2:6]
if not frame_idx in frame_data:
frame_data[frame_idx] = []
frame_data[frame_idx].append((tlwh, key, box_idx))
# 2. go through the video once and crop all the images (RAM should fit)
needed_track_box_imgs = {} # "query_track_id"/"gallery_track_id" => boxes
for batch in tqdm(frame_iter[0], total=frame_iter[1]):
image, scale, frame_idx = batch[0]
image = image.astype("uint8") # need uint8 type
if frame_idx in frame_data:
for tlwh, key, box_idx in frame_data[frame_idx]:
if key not in needed_track_box_imgs:
needed_track_box_imgs[key] = {}
# check box valid
if valid_box(tlwh, image):
x, y, w, h = tlwh
x, y, w, h = int(x), int(y), int(w), int(h)
#print(x, y, w, h)
#print(image[y:y+h, x:x+w])
box_img = cv2.cvtColor(
image[y:y+h, x:x+w], cv2.COLOR_BGR2RGB)
needed_track_box_imgs[key][box_idx] = box_img
# 3. for each query and gallery tracklets, extract feature and compare
# add the distance to the candidates
def get_features_and_compare(candidates, extractor, method="min_all"):
# method: "min_all" / "avg"
# 1. minimum among N (query_tracklet) by M (one gallery tracklet) dist
# 2. avg pool feature and then compute eu dist
assert method in ["min_all", "avg"]
remove_track_ids = {}
for query_track_id in candidates:
query_key = "query_%s" % query_track_id
query_boxes = needed_track_box_imgs[query_key] # box_idx => img [H, W, 3]
if not query_boxes:
remove_track_ids[query_track_id] = 1
continue
# front more important
box_idxs = sorted(list(query_boxes.keys()))
# torch tensor
query_features = extractor([query_boxes[i] for i in box_idxs]) # [K, 512]
gallery_featuress = []
# check the gallery boxes first
temp_list = []
for gallery_track_id, o, c in candidates[query_track_id]:
gallery_key = "gallery_%s" % gallery_track_id
gallery_boxes = needed_track_box_imgs[gallery_key] # box_idx => img
if gallery_boxes:
temp_list.append([gallery_track_id, o, c])
if not temp_list:
remove_track_ids[query_track_id] = 1
continue
candidates[query_track_id] = temp_list
for gallery_track_id, _, _ in candidates[query_track_id]:
gallery_key = "gallery_%s" % gallery_track_id
gallery_boxes = needed_track_box_imgs[gallery_key] # box_idx => img
assert gallery_boxes
# front more important
box_idxs = sorted(list(gallery_boxes.keys()))
gallery_features = extractor([gallery_boxes[i] for i in box_idxs])
#features = features.cpu().numpy()
gallery_featuress.append(gallery_features)
assert gallery_featuress
# # N query frames, M gallery tracklet with each K frames
# query_features [N, 512]
# gallery_featuress M of [K, 512], K is variable
if method == "avg":
# avg pooling
# [1, 512]
query_feature = torch.mean(query_features, 0, keepdim=True)
for i in range(len(gallery_featuress)):
gallery_featuress[i] = torch.mean(gallery_featuress[i], 0)
# [K, 512]
gallery_featuress = torch.stack(gallery_featuress)
# [1, K]
distmat = compute_distance_matrix(
query_feature, gallery_featuress, metric="euclidean")
elif method == "min_all":
# compute all query frame with all frames in each gallery track (N by K)
# take minimum
distmat = torch.zeros((1, len(gallery_featuress)))
for i, gallery_features in enumerate(gallery_featuress):
# [N, K]
this_distmat = compute_distance_matrix(
query_features, gallery_features, metric="euclidean")
distmat[0, i] = torch.min(this_distmat)
distmat = distmat.cpu().numpy()
for i in range(distmat.shape[1]):
candidates[query_track_id][i].append(distmat[0, i])
# sort the candidate according to the distance
candidates[query_track_id].sort(key=operator.itemgetter(-1))
for track_id in remove_track_ids:
del candidates[track_id]
if args.debug:
print("removed %s candidates due to bad boxes" % len(remove_track_ids))
get_features_and_compare(p_candidates, p_extractor, method=compare_method)
get_features_and_compare(v_candidates, v_extractor, method=compare_method)
def merge_candidates(candidates, track_data):
merge_map = {} # track_id -> track_id
# go through the candidate and decide merging or not
# 1. remove candidates with large dist
# 2. select longest one if muliple in reverse
# 3. return new track_data
reverse_candidates = {} # track_id -> list of track_id
for query_track_id in candidates:
match_list = [(o[0], o[-1])
for o in candidates[query_track_id]
if o[-1] < reject_dist_thres]
if match_list:
# top 1 is used
match_track_id, dist = match_list[0]
if match_track_id not in reverse_candidates:
reverse_candidates[match_track_id] = []
reverse_candidates[match_track_id].append((
query_track_id, dist, len(track_data[query_track_id])))
for track_id in reverse_candidates:
if len(reverse_candidates[track_id]) > 1:
reverse_candidates[track_id].sort(key=operator.itemgetter(2),
reverse=True)
matched_prev_track_id, _, _ = reverse_candidates[track_id][0]
# find the earliest track_id
while matched_prev_track_id in merge_map:
matched_prev_track_id = merge_map[matched_prev_track_id]
assert track_id not in merge_map
assert matched_prev_track_id not in merge_map
merge_map[track_id] = matched_prev_track_id
if args.debug:
print(candidates, merge_map)
new_track_data = {}
for track_id in track_data:
this_track_data = track_data[track_id][:, :] # [K, 7]
if track_id in merge_map:
track_id = merge_map[track_id]
this_track_data[:, 1] = track_id
if track_id in new_track_data:
new_track_data[track_id] = np.concatenate(
[new_track_data[track_id], this_track_data], axis=0)
else:
new_track_data[track_id] = this_track_data
return new_track_data
p_track_data = merge_candidates(p_candidates, p_track_data)
v_track_data = merge_candidates(v_candidates, v_track_data)
return p_track_data, v_track_data
def save_new_track(cat_name, track_data, out_dir, videoname):
track_results = sorted(
[l.tolist() for t in track_data for l in track_data[t]],
key=lambda x: (x[0], x[1]))
out_file_dir = os.path.join(out_dir, videoname, cat_name)
if not os.path.exists(out_file_dir):
os.makedirs(out_file_dir)
out_file = os.path.join(
out_file_dir, os.path.splitext(videoname)[0] + ".txt")
with open(out_file, "w") as fw:
for row in track_results:
line = "%d,%d,%.2f,%.2f,%.2f,%.2f,%.4f,-1,-1,-1" % (
row[0], row[1], row[2], row[3], row[4], row[5], row[6])
fw.write(line + "\n")
if __name__ == "__main__":
args = parser.parse_args()
# still has the .mp4/.avi
videonames = [os.path.basename(line.strip())
for line in open(args.videonamelst, "r").readlines()]
if not os.path.exists(args.newfilepath):
os.makedirs(args.newfilepath)
if args.person_reid_model is None or args.vehicle_reid_model is None:
raise Exception("Please provide models for person and vehicle!")
# assuming your GPU can fit both model at once
person_reid_extractor = FeatureExtractor(
model_name="osnet_x1_0",
model_path=args.person_reid_model,
image_size=(256, 128), # (h, w)
device="cuda:%d" % args.gpuid
)
vehicle_reid_extractor = FeatureExtractor(
model_name="resnet101",
model_path=args.vehicle_reid_model,
image_size=(128, 256),
device="cuda:%d" % args.gpuid
)
print("Model loaded.")
total_p_track = [0, 0] # before / after
total_v_track = [0, 0]
for videoname in tqdm(videonames):
clip_name = os.path.splitext(videoname)[0]
if args.debug and clip_name not in [
"2018-03-11.11-35-01.11-40-01.school.G424",
"2018-03-05.13-10-00.13-15-00.bus.G340",
"2018-03-07.17-25-06.17-30-06.school.G339"]:
continue
date, hr_slot, _ = parse_meva_clip_name(clip_name)
video_path = os.path.join(args.video_path, date, hr_slot, videoname)
# start reading video frames now!
if args.use_lijun_video_loader:
vcap = VideoReader(video_path)
frame_count = int(vcap.length)
elif args.use_moviepy:
vcap = VideoFileClip(video_path, audio=False)
frame_count = int(vcap.fps * vcap.duration) # uh
vcap = vcap.iter_frames()
else:
try:
vcap = cv2.VideoCapture(video_path)
if not vcap.isOpened():
raise Exception("cannot open %s" % video_path)
except Exception as e:
# raise e
# just move on to the next video
print("warning, cannot open %s" % video_path)
continue
# opencv 3/4
frame_count = vcap.get(cv2.CAP_PROP_FRAME_COUNT)
# start reading frames into queues now
video_queuer = VideoEnqueuer(
args, vcap, frame_count, frame_gap=1, # no skipping frames
prefetch=100,
start=True, is_moviepy=args.use_moviepy,
batch_size=1)
get_frame_batches = video_queuer.get()
# 1. read the tracklets and identify potential matching tracklets
# for each query tracklet
person_track_file = os.path.join(
args.filepath, videoname, "Person", "%s.txt" % (
os.path.splitext(videoname)[0]))
vehicle_track_file = os.path.join(
args.filepath, videoname, "Vehicle", "%s.txt" % (
os.path.splitext(videoname)[0]))
if not os.path.exists(person_track_file) or not os.path.exists(vehicle_track_file):
tqdm.write("warning, skipping %s due to track file not exists" % clip_name)
if not args.use_lijun_video_loader and not args.use_moviepy:
vcap.release()
video_queuer.stop()
continue
ori_p_track_num, p_track_data, p_candidates = preprocess(
person_track_file, args.tol_num_frame,
args.expand_width_p, args.expand_height_p)
ori_v_track_num, v_track_data, v_candidates = preprocess(
vehicle_track_file, args.tol_num_frame,
args.expand_width_p, args.expand_height_p)
# 2. compute similarities between query tracklet and the candidates. It is
# a match if below threshold
new_p_track_data, new_v_track_data = reid(
(get_frame_batches, video_queuer.num_batches),
person_reid_extractor, vehicle_reid_extractor,
p_track_data, p_candidates, v_track_data, v_candidates,
feature_box_num=args.feature_box_num,
feature_box_gap=args.feature_box_gap,
reject_dist_thres=args.reject_dist_thres,
compare_method=args.compare_method)
if args.debug:
print("person track %s -> %s" % (
ori_p_track_num, len(new_p_track_data)))
print("vehicle track %s -> %s" % (
ori_v_track_num, len(new_v_track_data)))
# save new file
save_new_track("Person", new_p_track_data, args.newfilepath, videoname)
save_new_track("Vehicle", new_v_track_data, args.newfilepath, videoname)
total_p_track[0] += ori_p_track_num
total_p_track[1] += len(new_p_track_data)
total_v_track[0] += ori_v_track_num
total_v_track[1] += len(new_v_track_data)
if not args.use_lijun_video_loader and not args.use_moviepy:
vcap.release()
video_queuer.stop()
print("total person track %s -> %s" % (total_p_track[0], total_p_track[1]))
print("total vehicle track %s -> %s" % (total_v_track[0], total_v_track[1]))