diff --git a/mmdet2trt/apis/inference.py b/mmdet2trt/apis/inference.py index 6ac84aa..93e3434 100644 --- a/mmdet2trt/apis/inference.py +++ b/mmdet2trt/apis/inference.py @@ -182,19 +182,20 @@ def forward(self, img, img_metas, *args, **kwargs): masks = masks.detach().cpu().numpy() num_classes = len(self.CLASSES) class_agnostic = True - segms_results = [] - for i in range(batch_size): - segms_results = FCNMaskHead.get_seg_masks( - Addict( - num_classes=num_classes, - class_agnostic=class_agnostic), - masks, - old_dets, - labels, - rcnn_test_cfg=Addict(mask_thr_binary=0.5), - ori_shape=img_metas[i]['ori_shape'], - scale_factor=scale_factor, - rescale=rescale) + segms_results = [[] for _ in range(num_classes)] + if num_dets>0: + for i in range(batch_size): + segms_results = FCNMaskHead.get_seg_masks( + Addict( + num_classes=num_classes, + class_agnostic=class_agnostic), + masks, + old_dets, + labels, + rcnn_test_cfg=Addict(mask_thr_binary=0.5), + ori_shape=img_metas[i]['ori_shape'], + scale_factor=scale_factor, + rescale=rescale) results.append((dets_results, segms_results)) else: results.append(dets_results)