Skip to content

Commit

Permalink
fix bug when mask number == 0 (#114)
Browse files Browse the repository at this point in the history
  • Loading branch information
chenxinfeng4 authored May 29, 2022
1 parent 536b8d9 commit cca99f6
Showing 1 changed file with 14 additions and 13 deletions.
27 changes: 14 additions & 13 deletions mmdet2trt/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,19 +182,20 @@ def forward(self, img, img_metas, *args, **kwargs):
masks = masks.detach().cpu().numpy()
num_classes = len(self.CLASSES)
class_agnostic = True
segms_results = []
for i in range(batch_size):
segms_results = FCNMaskHead.get_seg_masks(
Addict(
num_classes=num_classes,
class_agnostic=class_agnostic),
masks,
old_dets,
labels,
rcnn_test_cfg=Addict(mask_thr_binary=0.5),
ori_shape=img_metas[i]['ori_shape'],
scale_factor=scale_factor,
rescale=rescale)
segms_results = [[] for _ in range(num_classes)]
if num_dets>0:
for i in range(batch_size):
segms_results = FCNMaskHead.get_seg_masks(
Addict(
num_classes=num_classes,
class_agnostic=class_agnostic),
masks,
old_dets,
labels,
rcnn_test_cfg=Addict(mask_thr_binary=0.5),
ori_shape=img_metas[i]['ori_shape'],
scale_factor=scale_factor,
rescale=rescale)
results.append((dets_results, segms_results))
else:
results.append(dets_results)
Expand Down

0 comments on commit cca99f6

Please sign in to comment.