-
Notifications
You must be signed in to change notification settings - Fork 109
/
rpn.py
539 lines (463 loc) · 21.2 KB
/
rpn.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
import torch
import torchvision
from torch import Tensor, nn
from torch.jit.annotations import Dict, List, Optional, Tuple
from torch.nn import functional as F
from torchvision.models.detection import _utils as det_utils
from torchvision.models.detection.image_list import ImageList
from torchvision.ops import boxes as box_ops
@torch.jit.unused
def _onnx_get_num_anchors_and_pre_nms_top_n(ob, orig_pre_nms_top_n):
# type: (Tensor, int) -> Tuple[int, int]
from torch.onnx import operators
num_anchors = operators.shape_as_tensor(ob)[1].unsqueeze(0)
pre_nms_top_n = torch.min(
torch.cat(
(torch.tensor([orig_pre_nms_top_n], dtype=num_anchors.dtype), num_anchors),
0,
)
)
return num_anchors, pre_nms_top_n
class AnchorGenerator(nn.Module):
    """
    Module that generates anchors for a set of feature maps and
    image sizes.

    The module supports computing anchors at multiple sizes and aspect ratios
    per feature map. This module assumes aspect ratio = height / width for
    each anchor.

    sizes and aspect_ratios should have the same number of elements, and it should
    correspond to the number of feature maps.

    sizes[i] and aspect_ratios[i] can have an arbitrary number of elements,
    and AnchorGenerator will output a set of sizes[i] * aspect_ratios[i] anchors
    per spatial location for feature map i.

    Arguments:
        sizes (Tuple[Tuple[int]]): anchor sizes for each feature map. A flat
            tuple of ints is also accepted and treated as one size per
            feature map.
        aspect_ratios (Tuple[Tuple[float]]): anchor aspect ratios
            (height / width) for each feature map. A flat tuple is also
            accepted and reused for every feature map.
    """

    # Declared types for the attributes that __init__ initialises to None / {}.
    # The docstring above must remain the first statement of the class body,
    # otherwise Python does not attach it as AnchorGenerator.__doc__.
    __annotations__ = {
        "cell_anchors": Optional[List[torch.Tensor]],
        "_cache": Dict[str, List[torch.Tensor]],
    }

    def __init__(
        self,
        sizes=(128, 256, 512),
        aspect_ratios=(0.5, 1.0, 2.0),
    ):
        super(AnchorGenerator, self).__init__()

        if not isinstance(sizes[0], (list, tuple)):
            # TODO change this
            # A flat tuple becomes one single-size tuple per feature map.
            sizes = tuple((s,) for s in sizes)
        if not isinstance(aspect_ratios[0], (list, tuple)):
            # A flat tuple of ratios is shared by every feature map.
            aspect_ratios = (aspect_ratios,) * len(sizes)

        assert len(sizes) == len(aspect_ratios)

        self.sizes = sizes
        self.aspect_ratios = aspect_ratios
        # Zero-centred base anchors per feature map, built lazily by set_cell_anchors.
        self.cell_anchors = None
        # Grid anchors keyed by str(grid_sizes) + str(strides).
        self._cache = {}

    # TODO: https://github.com/pytorch/pytorch/issues/26792
    # For every (aspect_ratios, scales) combination, output a zero-centered
    # anchor with those values, ordered ratio-major.
    # (scales, aspect_ratios) are usually an element of
    # zip(self.sizes, self.aspect_ratios)
    # This method assumes aspect ratio = height / width for an anchor.
    def generate_anchors(
        self, scales, aspect_ratios, dtype=torch.float32, device="cpu"
    ):
        # type: (List[int], List[float], int, Device) -> Tensor # noqa: F821
        scales = torch.as_tensor(scales, dtype=dtype, device=device)
        aspect_ratios = torch.as_tensor(aspect_ratios, dtype=dtype, device=device)
        h_ratios = torch.sqrt(aspect_ratios)
        w_ratios = 1 / h_ratios

        ws = (w_ratios[:, None] * scales[None, :]).view(-1)
        hs = (h_ratios[:, None] * scales[None, :]).view(-1)

        # (x1, y1, x2, y2) boxes centred on the origin, rounded to integers.
        base_anchors = torch.stack([-ws, -hs, ws, hs], dim=1) / 2
        return base_anchors.round()

    def set_cell_anchors(self, dtype, device):
        # type: (int, Device) -> None # noqa: F821
        if self.cell_anchors is not None:
            cell_anchors = self.cell_anchors
            assert cell_anchors is not None
            # suppose that all anchors have the same device and dtype,
            # which is a valid assumption in the current state of the codebase.
            # Rebuild when either differs, so a dtype change does not keep
            # serving anchors generated for the previous dtype.
            if cell_anchors[0].device == device and cell_anchors[0].dtype == dtype:
                return

        cell_anchors = [
            self.generate_anchors(sizes, aspect_ratios, dtype, device)
            for sizes, aspect_ratios in zip(self.sizes, self.aspect_ratios)
        ]
        self.cell_anchors = cell_anchors

    def num_anchors_per_location(self):
        # One anchor per (size, aspect ratio) pair, for each feature map.
        return [len(s) * len(a) for s, a in zip(self.sizes, self.aspect_ratios)]

    # For every combination of (a, (g, s), i) in
    # (self.cell_anchors, zip(grid_sizes, strides), 0:2),
    # output g[i] anchors that are s[i] distance apart in direction i,
    # with the same dimensions as a.
    def grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        anchors = []
        cell_anchors = self.cell_anchors
        assert cell_anchors is not None

        for size, stride, base_anchors in zip(grid_sizes, strides, cell_anchors):
            grid_height, grid_width = size
            stride_height, stride_width = stride
            device = base_anchors.device

            # For output anchor, compute [x_center, y_center, x_center, y_center]
            shifts_x = (
                torch.arange(0, grid_width, dtype=torch.float32, device=device)
                * stride_width
            )
            shifts_y = (
                torch.arange(0, grid_height, dtype=torch.float32, device=device)
                * stride_height
            )
            shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
            shift_x = shift_x.reshape(-1)
            shift_y = shift_y.reshape(-1)
            shifts = torch.stack((shift_x, shift_y, shift_x, shift_y), dim=1)

            # For every (base anchor, output anchor) pair,
            # offset each zero-centered base anchor by the center of the output anchor.
            anchors.append(
                (shifts.view(-1, 1, 4) + base_anchors.view(1, -1, 4)).reshape(-1, 4)
            )

        return anchors

    def cached_grid_anchors(self, grid_sizes, strides):
        # type: (List[List[int]], List[List[Tensor]]) -> List[Tensor]
        key = str(grid_sizes) + str(strides)
        if key in self._cache:
            return self._cache[key]
        anchors = self.grid_anchors(grid_sizes, strides)
        self._cache[key] = anchors
        return anchors

    def forward(self, image_list, feature_maps):
        # type: (ImageList, List[Tensor]) -> List[Tensor]
        grid_sizes = [feature_map.shape[-2:] for feature_map in feature_maps]
        image_size = image_list.tensors.shape[-2:]
        dtype, device = feature_maps[0].dtype, feature_maps[0].device
        # Integer stride (padded image size // feature-map size) per level and axis.
        strides = [
            [
                torch.tensor(image_size[0] // g[0], dtype=torch.int64, device=device),
                torch.tensor(image_size[1] // g[1], dtype=torch.int64, device=device),
            ]
            for g in grid_sizes
        ]
        self.set_cell_anchors(dtype, device)
        anchors_over_all_feature_maps = self.cached_grid_anchors(grid_sizes, strides)

        # Every image in the batch shares the padded size, so each gets the
        # same per-level anchors concatenated into one tensor.
        anchors = torch.jit.annotate(List[List[torch.Tensor]], [])
        for _ in range(len(image_list.image_sizes)):
            anchors_in_image = [
                anchors_per_feature_map
                for anchors_per_feature_map in anchors_over_all_feature_maps
            ]
            anchors.append(anchors_in_image)
        anchors = [torch.cat(anchors_per_image) for anchors_per_image in anchors]
        # Clear the cache so it cannot grow without bound.
        # NOTE(review): since it is cleared on every call and
        # cached_grid_anchors runs once per forward, the cache never hits
        # across calls -- confirm whether caching is still wanted here.
        self._cache.clear()
        return anchors
class RPNHead(nn.Module):
    """
    Simple RPN head: a shared 3x3 convolution (followed by ReLU) feeding two
    sibling 1x1 convolutions, one for objectness logits and one for box
    regression deltas.

    Arguments:
        in_channels (int): number of channels of the input feature
        num_anchors (int): number of anchors to be predicted per location;
            the regression branch emits num_anchors * 4 channels
    """

    def __init__(self, in_channels, num_anchors):
        super(RPNHead, self).__init__()
        self.conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, stride=1, padding=1
        )
        self.cls_logits = nn.Conv2d(in_channels, num_anchors, kernel_size=1, stride=1)
        self.bbox_pred = nn.Conv2d(
            in_channels, num_anchors * 4, kernel_size=1, stride=1
        )

        # Same registration order as self.children(): weights ~ N(0, 0.01),
        # biases set to zero.
        for conv_layer in (self.conv, self.cls_logits, self.bbox_pred):
            torch.nn.init.normal_(conv_layer.weight, std=0.01)
            torch.nn.init.constant_(conv_layer.bias, 0)

    def forward(self, x):
        # type: (List[Tensor]) -> Tuple[List[Tensor], List[Tensor]]
        objectness_maps = []
        delta_maps = []
        for level_features in x:
            shared = F.relu(self.conv(level_features))
            objectness_maps.append(self.cls_logits(shared))
            delta_maps.append(self.bbox_pred(shared))
        return objectness_maps, delta_maps
def permute_and_flatten(layer, N, A, C, H, W):
    # type: (Tensor, int, int, int, int, int) -> Tensor
    """
    Reorder a per-level head output into one row per (location, anchor).

    ``layer`` is viewed as (N, -1, C, H, W), permuted to (N, H, W, -1, C)
    and flattened to (N, H * W * anchors, C). ``A`` is accepted for
    signature symmetry but not used; the anchor axis is inferred by ``view``.
    """
    return layer.view(N, -1, C, H, W).permute(0, 3, 4, 1, 2).reshape(N, -1, C)
def concat_box_prediction_layers(box_cls, box_regression):
    # type: (List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
    """
    Flatten and concatenate per-level head outputs across all feature levels.

    For each level, the anchor count A is taken from the regression map's
    channel count (A * 4) and the class count C from the classification
    map's channels (A * C). Both maps are reordered with
    ``permute_and_flatten``, concatenated along the anchor axis (dim 1), and
    then flattened over the batch: classification becomes (-1, C) and
    regression (-1, 4). This matches how labels are computed, i.e. for all
    feature levels concatenated.
    """
    flat_cls_levels = []
    flat_reg_levels = []
    for cls_level, reg_level in zip(box_cls, box_regression):
        N, AxC, H, W = cls_level.shape
        num_anchors = reg_level.shape[1] // 4
        num_classes = AxC // num_anchors
        flat_cls_levels.append(
            permute_and_flatten(cls_level, N, num_anchors, num_classes, H, W)
        )
        flat_reg_levels.append(
            permute_and_flatten(reg_level, N, num_anchors, 4, H, W)
        )

    merged_cls = torch.cat(flat_cls_levels, dim=1).flatten(0, -2)
    merged_reg = torch.cat(flat_reg_levels, dim=1).reshape(-1, 4)
    return merged_cls, merged_reg
class RegionProposalNetwork(torch.nn.Module):
    """
    Implements Region Proposal Network (RPN).

    Arguments:
        anchor_generator (AnchorGenerator): module that generates the anchors
            for a set of feature maps.
        head (nn.Module): module that computes the objectness and regression deltas
        fg_iou_thresh (float): minimum IoU between the anchor and the GT box so
            that they can be considered as positive during training of the RPN.
        bg_iou_thresh (float): maximum IoU between the anchor and the GT box so
            that they can be considered as negative during training of the RPN.
        batch_size_per_image (int): number of anchors that are sampled during
            training of the RPN for computing the loss
        positive_fraction (float): proportion of positive anchors in a mini-batch
            during training of the RPN
        pre_nms_top_n (Dict[str, int]): number of proposals to keep per feature
            level before applying NMS. It should contain two fields: "training"
            and "testing"; the one used is selected by ``self.training``
        post_nms_top_n (Dict[str, int]): number of proposals to keep per image
            after applying NMS. Same "training" / "testing" keys as pre_nms_top_n
        nms_thresh (float): NMS threshold used for postprocessing the RPN proposals
    """

    # NOTE(review): the "pre_nms_top_n" / "post_nms_top_n" keys share their
    # names with the accessor methods below, while the dicts themselves are
    # stored as self._pre_nms_top_n / self._post_nms_top_n -- confirm intended.
    __annotations__ = {
        "box_coder": det_utils.BoxCoder,
        "proposal_matcher": det_utils.Matcher,
        "fg_bg_sampler": det_utils.BalancedPositiveNegativeSampler,
        "pre_nms_top_n": Dict[str, int],
        "post_nms_top_n": Dict[str, int],
    }

    def __init__(
        self,
        anchor_generator,
        head,
        #
        fg_iou_thresh,
        bg_iou_thresh,
        batch_size_per_image,
        positive_fraction,
        #
        pre_nms_top_n,
        post_nms_top_n,
        nms_thresh,
    ):
        super(RegionProposalNetwork, self).__init__()
        self.anchor_generator = anchor_generator
        self.head = head
        self.box_coder = det_utils.BoxCoder(weights=(1.0, 1.0, 1.0, 1.0))

        # used for target assignment / loss
        self.box_similarity = box_ops.box_iou

        self.proposal_matcher = det_utils.Matcher(
            fg_iou_thresh,
            bg_iou_thresh,
            allow_low_quality_matches=True,
        )

        self.fg_bg_sampler = det_utils.BalancedPositiveNegativeSampler(
            batch_size_per_image, positive_fraction
        )

        # used by filter_proposals in both training and evaluation; the
        # "training"/"testing" entries are picked by the accessors below
        self._pre_nms_top_n = pre_nms_top_n
        self._post_nms_top_n = post_nms_top_n
        self.nms_thresh = nms_thresh
        # boxes smaller than this (after clipping) are discarded before NMS
        self.min_size = 1e-3

    def pre_nms_top_n(self):
        """Return the per-level pre-NMS budget for the current train/eval mode."""
        if self.training:
            return self._pre_nms_top_n["training"]
        return self._pre_nms_top_n["testing"]

    def post_nms_top_n(self):
        """Return the per-image post-NMS budget for the current train/eval mode."""
        if self.training:
            return self._post_nms_top_n["training"]
        return self._post_nms_top_n["testing"]

    def assign_targets_to_anchors(self, anchors, targets):
        # type: (List[Tensor], List[Dict[str, Tensor]]) -> Tuple[List[Tensor], List[Tensor]]
        """
        Label every anchor of every image and pick its matched GT box.

        Arguments:
            anchors (List[Tensor]): anchors for each image
            targets (List[Dict[str, Tensor]]): per-image targets; only the
                "boxes" entry is read

        Returns:
            labels (List[Tensor]): one float32 vector per image with 1.0 where
                the matcher returned a non-negative GT index, 0.0 where it
                returned BELOW_LOW_THRESHOLD and -1.0 where it returned
                BETWEEN_THRESHOLDS. Images without GT boxes are all 0.0.
            matched_gt_boxes (List[Tensor]): per image, one matched GT box per
                anchor (all zeros for images without GT boxes)
        """
        labels = []
        matched_gt_boxes = []
        for anchors_per_image, targets_per_image in zip(anchors, targets):
            gt_boxes = targets_per_image["boxes"]

            if gt_boxes.numel() == 0:
                # Background image (negative example)
                device = anchors_per_image.device
                matched_gt_boxes_per_image = torch.zeros(
                    anchors_per_image.shape, dtype=torch.float32, device=device
                )
                labels_per_image = torch.zeros(
                    (anchors_per_image.shape[0],), dtype=torch.float32, device=device
                )
            else:
                match_quality_matrix = box_ops.box_iou(gt_boxes, anchors_per_image)
                matched_idxs = self.proposal_matcher(match_quality_matrix)
                # get the targets corresponding GT for each proposal
                # NB: need to clamp the indices because we can have a single
                # GT in the image, and matched_idxs can be -2, which goes
                # out of bounds
                matched_gt_boxes_per_image = gt_boxes[matched_idxs.clamp(min=0)]

                labels_per_image = matched_idxs >= 0
                labels_per_image = labels_per_image.to(dtype=torch.float32)

                # Background (negative examples)
                bg_indices = matched_idxs == self.proposal_matcher.BELOW_LOW_THRESHOLD
                labels_per_image[bg_indices] = 0.0

                # discard indices that are between thresholds
                inds_to_discard = (
                    matched_idxs == self.proposal_matcher.BETWEEN_THRESHOLDS
                )
                labels_per_image[inds_to_discard] = -1.0

            labels.append(labels_per_image)
            matched_gt_boxes.append(matched_gt_boxes_per_image)
        return labels, matched_gt_boxes

    def _get_top_n_idx(self, objectness, num_anchors_per_level):
        # type: (Tensor, List[int]) -> Tensor
        """
        Per feature level, take the indices of the highest objectness scores
        (at most pre_nms_top_n(), capped at the level's anchor count) and
        shift them into the concatenated anchor axis.
        """
        r = []
        offset = 0
        for ob in objectness.split(num_anchors_per_level, 1):
            if torchvision._is_tracing():
                num_anchors, pre_nms_top_n = _onnx_get_num_anchors_and_pre_nms_top_n(
                    ob, self.pre_nms_top_n()
                )
            else:
                num_anchors = ob.shape[1]
                pre_nms_top_n = min(self.pre_nms_top_n(), num_anchors)
            _, top_n_idx = ob.topk(pre_nms_top_n, dim=1)
            r.append(top_n_idx + offset)
            offset += num_anchors
        return torch.cat(r, dim=1)

    def filter_proposals(
        self, proposals, objectness, image_shapes, num_anchors_per_level
    ):
        # type: (Tensor, Tensor, List[Tuple[int, int]], List[int]) -> Tuple[List[Tensor], List[Tensor]]
        """
        Reduce decoded proposals to the final per-image boxes and scores.

        Steps: keep the top pre_nms_top_n() proposals per level, then per
        image clip boxes to the image, drop boxes smaller than self.min_size,
        run NMS independently per level (batched_nms with the level index) and
        keep at most post_nms_top_n() of the survivors.
        """
        num_images = proposals.shape[0]
        device = proposals.device
        # do not backprop through objectness
        objectness = objectness.detach()
        objectness = objectness.reshape(num_images, -1)

        # feature-level index of every anchor, broadcast over the batch
        levels = [
            torch.full((n,), idx, dtype=torch.int64, device=device)
            for idx, n in enumerate(num_anchors_per_level)
        ]
        levels = torch.cat(levels, 0)
        levels = levels.reshape(1, -1).expand_as(objectness)

        # select top_n boxes independently per level before applying nms
        top_n_idx = self._get_top_n_idx(objectness, num_anchors_per_level)

        image_range = torch.arange(num_images, device=device)
        batch_idx = image_range[:, None]

        objectness = objectness[batch_idx, top_n_idx]
        levels = levels[batch_idx, top_n_idx]
        proposals = proposals[batch_idx, top_n_idx]

        final_boxes = []
        final_scores = []
        for boxes, scores, lvl, img_shape in zip(
            proposals, objectness, levels, image_shapes
        ):
            boxes = box_ops.clip_boxes_to_image(boxes, img_shape)
            keep = box_ops.remove_small_boxes(boxes, self.min_size)
            boxes, scores, lvl = boxes[keep], scores[keep], lvl[keep]
            # non-maximum suppression, independently done per level
            keep = box_ops.batched_nms(boxes, scores, lvl, self.nms_thresh)
            # keep only topk scoring predictions
            keep = keep[: self.post_nms_top_n()]
            boxes, scores = boxes[keep], scores[keep]
            final_boxes.append(boxes)
            final_scores.append(scores)
        return final_boxes, final_scores

    def compute_loss(self, objectness, pred_bbox_deltas, labels, regression_targets):
        # type: (Tensor, Tensor, List[Tensor], List[Tensor]) -> Tuple[Tensor, Tensor]
        """
        Compute the RPN losses on a sampled subset of anchors.

        Arguments:
            objectness (Tensor)
            pred_bbox_deltas (Tensor)
            labels (List[Tensor])
            regression_targets (List[Tensor])

        Returns:
            objectness_loss (Tensor): binary cross-entropy with logits over
                the sampled positive and negative anchors
            box_loss (Tensor): smooth L1 (beta=1/9, size_average=False) over
                the sampled positive anchors, divided by the total number of
                sampled anchors
        """
        sampled_pos_inds, sampled_neg_inds = self.fg_bg_sampler(labels)
        sampled_pos_inds = torch.nonzero(torch.cat(sampled_pos_inds, dim=0)).squeeze(1)
        sampled_neg_inds = torch.nonzero(torch.cat(sampled_neg_inds, dim=0)).squeeze(1)

        sampled_inds = torch.cat([sampled_pos_inds, sampled_neg_inds], dim=0)

        objectness = objectness.flatten()

        labels = torch.cat(labels, dim=0)
        regression_targets = torch.cat(regression_targets, dim=0)

        box_loss = (
            det_utils.smooth_l1_loss(
                pred_bbox_deltas[sampled_pos_inds],
                regression_targets[sampled_pos_inds],
                beta=1 / 9,
                size_average=False,
            )
            / (sampled_inds.numel())
        )

        objectness_loss = F.binary_cross_entropy_with_logits(
            objectness[sampled_inds], labels[sampled_inds]
        )

        return objectness_loss, box_loss

    def forward(
        self,
        images,  # type: ImageList
        features,  # type: Dict[str, Tensor]
        targets=None,  # type: Optional[List[Dict[str, Tensor]]]
    ):
        # type: (...) -> Tuple[List[Tensor], Dict[str, Tensor]]
        """
        Arguments:
            images (ImageList): images for which we want to compute the predictions
            features (Dict[str, Tensor]): features computed from the images that are
                used for computing the predictions. Each tensor in the dict
                corresponds to a different feature level
            targets (List[Dict[str, Tensor]]): ground-truth present in the image (optional).
                If provided, each element in the list should contain a field `boxes`,
                with the locations of the ground-truth boxes. Required in training
                mode; when given in evaluation mode the losses are computed too.

        Returns:
            boxes (List[Tensor]): the predicted boxes from the RPN, one Tensor per
                image.
            losses (Dict[str, Tensor]): the losses for the model. Empty when
                not training and no targets are given.
        """
        # RPN uses all feature maps that are available
        features = list(features.values())
        objectness, pred_bbox_deltas = self.head(features)
        anchors = self.anchor_generator(images, features)

        num_images = len(anchors)
        # per level: product of the per-image (channels, H, W) of the objectness map
        num_anchors_per_level_shape_tensors = [o[0].shape for o in objectness]
        num_anchors_per_level = [
            s[0] * s[1] * s[2] for s in num_anchors_per_level_shape_tensors
        ]
        objectness, pred_bbox_deltas = concat_box_prediction_layers(
            objectness, pred_bbox_deltas
        )
        # apply pred_bbox_deltas to anchors to obtain the decoded proposals
        # note that we detach the deltas because Faster R-CNN does not backprop
        # through the proposals
        proposals = self.box_coder.decode(pred_bbox_deltas.detach(), anchors)
        proposals = proposals.view(num_images, -1, 4)
        boxes, scores = self.filter_proposals(
            proposals, objectness, images.image_sizes, num_anchors_per_level
        )

        losses = {}
        if self.training or targets is not None:
            assert targets is not None
            labels, matched_gt_boxes = self.assign_targets_to_anchors(anchors, targets)
            regression_targets = self.box_coder.encode(matched_gt_boxes, anchors)
            loss_objectness, loss_rpn_box_reg = self.compute_loss(
                objectness, pred_bbox_deltas, labels, regression_targets
            )
            losses = {
                "loss_objectness": loss_objectness,
                "loss_rpn_box_reg": loss_rpn_box_reg,
            }
        return boxes, losses