Skip to content

Commit

Permalink
Merge remote-tracking branch 'OpenPCDet-fnozarian/DA-52-cos-score-wei…
Browse files Browse the repository at this point in the history
…ghting-pooled-refactor' into DA-52-cos-score-weighting-pooled-refactor
  • Loading branch information
fnozarian committed Sep 3, 2023
2 parents 87bfa2c + 3889d9d commit 0685d17
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 3 deletions.
4 changes: 3 additions & 1 deletion pcdet/models/roi_heads/roi_head_template.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def proposal_layer(self, batch_dict, nms_config):
rois = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE, batch_box_preds.shape[-1]))
roi_scores = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE))
roi_labels = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE), dtype=torch.long)

roi_scores_multiclass = batch_box_preds.new_zeros((batch_size, nms_config.NMS_POST_MAXSIZE, batch_cls_preds.shape[-1]))
for index in range(batch_size):
if batch_dict.get('batch_index', None) is not None:
assert batch_cls_preds.shape.__len__() == 2
Expand All @@ -142,9 +142,11 @@ def proposal_layer(self, batch_dict, nms_config):
rois[index, :len(selected), :] = box_preds[selected]
roi_scores[index, :len(selected)] = cur_roi_scores[selected]
roi_labels[index, :len(selected)] = cur_roi_labels[selected]
roi_scores_multiclass[index, :len(selected), :] = cls_preds[selected]

batch_dict['rois'] = rois
batch_dict['roi_scores'] = roi_scores
batch_dict['roi_scores_multiclass'] = roi_scores_multiclass
batch_dict['roi_labels'] = roi_labels + 1
batch_dict['has_class_labels'] = True if batch_cls_preds.shape[-1] > 1 else False
batch_dict.pop('batch_index', None)
Expand Down
4 changes: 2 additions & 2 deletions tools/cfgs/kitti_models/pv_rcnn_ssl_60.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ MODEL:
ENABLE_VIS: False
ENABLE_PROTO_CONTRASTIVE_LOSS: False
PROTO_CONTRASTIVE_LOSS_WEIGHT: 1.0
ENABLE_SOFT_TEACHER: True
ENABLE_SOFT_TEACHER: False
ENABLE_ULB_CLS_DIST_LOSS: False
ENABLE_EVAL: True
METRICS_PRED_TYPES: [roi_pl_gt, pl_gt_metrics_before_filtering]
Expand Down Expand Up @@ -221,7 +221,7 @@ MODEL:

REG_FG_THRESH: 0.55
UNLABELED_REG_FG_THRESH: [0.55, 0.55, 0.55]
UNLABELED_SAMPLER_TYPE: subsample_unlabeled_rois_default
UNLABELED_SAMPLER_TYPE: subsample_labeled_rois #subsample_unlabeled_rois_default
UNLABELED_SAMPLE_EASY_BG: False
UNLABELED_SHARP_RCNN_CLS_LABELS: True
UNLABELED_USE_CALIBRATED_IOUS: True
Expand Down

0 comments on commit 0685d17

Please sign in to comment.