-
Notifications
You must be signed in to change notification settings - Fork 158
/
Copy pathvot_tool.py
205 lines (155 loc) · 6.7 KB
/
vot_tool.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
"""
\file vot.py
@brief Python utility functions for VOT toolkit integration
@author Luka Cehovin, Alessio Dore
@date 2023
"""
import os
import collections
import numpy as np
try:
import trax
except ImportError:
raise Exception('TraX support not found. Please add trax module to Python path.')
# Lightweight region types exchanged between the tracker and this wrapper.
# Axis-aligned box given by its position (x, y) and its size (width, height).
Rectangle = collections.namedtuple('Rectangle', 'x y width height')
# A single 2-D point; used as a polygon vertex.
Point = collections.namedtuple('Point', 'x y')
# A polygon described by its sequence of Point vertices.
Polygon = collections.namedtuple('Polygon', 'points')
# Field-less marker for an empty region (reported to the client as a zero-size rectangle).
Empty = collections.namedtuple('Empty', ())
class VOT(object):
    """ Base class for VOT toolkit integration in Python.

    This class is only a wrapper around the TraX protocol and can be used for single or multi-object tracking.
    The wrapper assumes that the experiment will provide new objects only at the first frame and will fail otherwise."""

    def __init__(self, region_format, channels=None, multiobject: bool = None):
        """ Constructor for the VOT wrapper.

        Starts a TraX server, blocks until the client sends the initialization
        request, converts the initial objects to the local region types and
        acknowledges them back to the client.

        Args:
            region_format: Region format options, one of trax.Region.RECTANGLE,
                trax.Region.POLYGON or trax.Region.MASK
            channels: Name of the channel configuration supported by the tracker:
                None (color only), 'rgbd', 'rgbt' or 'ir'
            multiobject: Whether to use multi-object tracking; when None it is
                enabled only if the VOT_MULTI_OBJECT environment variable is '1'

        Raises:
            ValueError: if ``channels`` is not a recognized configuration
        """
        assert(region_format in [trax.Region.RECTANGLE, trax.Region.POLYGON, trax.Region.MASK])

        if multiobject is None:
            multiobject = os.environ.get('VOT_MULTI_OBJECT', '0') == '1'

        channels = self._resolve_channels(channels)

        self._trax = trax.Server([region_format], [trax.Image.PATH], channels, metadata=dict(vot="python"), multiobject=multiobject)

        request = self._trax.wait()
        assert(request.type == 'initialize')

        self._objects = []

        # Single-object mode must receive exactly one object; multi-object mode at least one.
        assert len(request.objects) > 0 and (multiobject or len(request.objects) == 1)

        # Convert each TraX region into the local namedtuple / numpy representation.
        for obj, _ in request.objects:
            if isinstance(obj, trax.Polygon):
                self._objects.append(Polygon([Point(x[0], x[1]) for x in obj]))
            elif isinstance(obj, trax.Mask):
                self._objects.append(obj.array(True))
            else:
                self._objects.append(Rectangle(*obj.bounds()))

        # Cache the first frame so the first call to frame() returns it without
        # waiting on the client; a single channel is unwrapped to a plain path.
        self._image = [x.path() for _, x in request.image.items()]
        if len(self._image) == 1:
            self._image = self._image[0]

        self._multiobject = multiobject

        # Acknowledge initialization by echoing the received objects as the status.
        self._trax.status(request.objects)

    @staticmethod
    def _resolve_channels(channels):
        """ Map a channel configuration name to the list of TraX channel names.

        Args:
            channels: None, 'rgbd', 'rgbt' or 'ir'

        Returns:
            list of channel names passed to the TraX server

        Raises:
            ValueError: for any other configuration value
        """
        if channels is None:
            return ['color']
        elif channels == 'rgbd':
            return ['color', 'depth']
        elif channels == 'rgbt':
            return ['color', 'ir']
        elif channels == 'ir':
            return ['ir']
        # ValueError subclasses Exception, so callers catching Exception are unaffected.
        raise ValueError('Illegal configuration {}.'.format(channels))

    def region(self):
        """
        Returns initialization region for the first frame in single object tracking mode.

        Returns:
            initialization region
        """
        assert not self._multiobject

        return self._objects[0]

    def objects(self):
        """
        Returns initialization regions for the first frame in multi object tracking mode.

        Returns:
            initialization regions for all objects
        """
        return self._objects

    def report(self, status, confidence=None):
        """
        Report the tracking results to the client.

        Arguments:
            status: region for the frame or a list of regions in case of multi object tracking;
                a region is a Rectangle, Polygon, Empty, numpy mask array or None
            confidence: confidence for the object detection, used only in single object tracking mode
        """
        def convert(region):
            """ Convert region to TraX format """
            # None and Empty are both reported as an empty (zero-size) rectangle.
            if region is None:
                return trax.Rectangle.create(0, 0, 0, 0)

            assert isinstance(region, (Empty, Rectangle, Polygon, np.ndarray))
            if isinstance(region, Empty):
                return trax.Rectangle.create(0, 0, 0, 0)
            elif isinstance(region, Polygon):
                return trax.Polygon.create([(x.x, x.y) for x in region.points])
            elif isinstance(region, np.ndarray):
                return trax.Mask.create(region)
            else:
                return trax.Rectangle.create(region.x, region.y, region.width, region.height)

        if not self._multiobject:
            status = convert(status)
        else:
            assert isinstance(status, (list, tuple))
            status = [(convert(x), {}) for x in status]

        properties = {}

        # Confidence is only forwarded in single-object mode; it is ignored otherwise.
        if confidence is not None and not self._multiobject:
            properties['confidence'] = confidence

        self._trax.status(status, properties)

    def frame(self):
        """
        Get a frame (image path) from client.

        Returns:
            absolute path of the image (a list of paths when several channels
            are used), or None when the client sent a non-frame request
        """
        # The first frame arrived with the initialization request; hand it out once.
        if hasattr(self, "_image"):
            image = self._image
            del self._image
            return image

        request = self._trax.wait()

        # Only the first frame can declare new objects for now
        assert request.objects is None or len(request.objects) == 0

        if request.type == 'frame':
            image = [x.path() for _, x in request.image.items()]
            if len(image) == 1:
                return image[0]
            return image
        else:
            return None

    def quit(self):
        """ Quit the tracker.

        Safe to call more than once: the TraX server handle is released on the
        first call, so a later call (e.g. from the destructor) does nothing.
        """
        if hasattr(self, '_trax'):
            server = self._trax
            # Drop the handle before quitting so quit is sent at most once.
            del self._trax
            server.quit()

    def __del__(self):
        """ Destructor for the tracker, calls quit. """
        self.quit()
class VOTManager(object):
    """ VOT Manager provides a simple interface for running multiple single object trackers in parallel. Trackers should implement a factory interface. """

    def __init__(self, factory, region_format, channels=None):
        """ Constructor for the manager.

        The factory should be a callable that accepts two arguments: image and region and returns a callable that accepts a single argument (image) and returns a region.

        Args:
            factory: Factory function for creating trackers
            region_format: Region format options
            channels: Channel configuration forwarded to VOT (None, 'rgbd', 'rgbt' or 'ir')
        """
        self._handle = VOT(region_format, channels, multiobject=True)
        self._factory = factory

    def run(self):
        """ Run the tracker, the tracking loop is implemented in this function, so it will block until the client terminates the connection."""
        objects = self._handle.objects()

        # The first frame is only used to build one tracker per initial region;
        # its status was already sent when the VOT handle was constructed.
        image = self._handle.frame()
        if not image:
            return

        trackers = [self._factory(image, region) for region in objects]

        # frame() returns None once the client sends anything other than a frame.
        while True:
            image = self._handle.frame()
            if not image:
                break
            status = [tracker(image) for tracker in trackers]
            self._handle.report(status)