Skip to content

Commit

Permalink
Format pv_rcnn_ssl.py and adamatch.py.
Browse files Browse the repository at this point in the history
  • Loading branch information
fnozarian committed Sep 3, 2023
1 parent 1885ef1 commit 87bfa2c
Show file tree
Hide file tree
Showing 2 changed files with 213 additions and 196 deletions.
21 changes: 9 additions & 12 deletions pcdet/models/detectors/pv_rcnn_ssl.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
from tools.visual_utils import open3d_vis_utils as V
from collections import defaultdict
from pcdet.utils.weighting_methods import build_thresholding_method
from visual_utils import open3d_vis_utils as V


class DynamicThreshRegistry(object):
def __init__(self, **kwargs):
Expand Down Expand Up @@ -134,17 +132,17 @@ def forward(self, batch_dict):

def _rectify_pl_scores(self, batch_dict_ema, unlabeled_inds):
thresh_reg = self.thresh_registry.get(tag='pl_adaptive_thresh')
pred_weak_aug_before_nms = torch.sigmoid(batch_dict_ema['batch_cls_preds']).detach().clone()
pred_wa = torch.sigmoid(batch_dict_ema['batch_cls_preds']).detach().clone()
# to be used later for updating the EMA (p_model/p_target)
pred_weak_aug_before_nms_org = pred_weak_aug_before_nms.clone()
pred_weak_aug_before_nms_org = pred_wa.clone()
if thresh_reg.iteration_count > 0:
pred_weak_aug_unlab_before_nms = pred_weak_aug_before_nms[unlabeled_inds, ...]
pred_weak_aug_unlab_before_nms_aligned = pred_weak_aug_unlab_before_nms * (thresh_reg.ema_pred_weak_aug_lab_before_nms + 1e-6) / (thresh_reg.ema_pred_weak_aug_unlab_before_nms + 1e-6)
pred_weak_aug_unlab_before_nms_aligned = thresh_reg.normalize_(pred_weak_aug_unlab_before_nms_aligned)
pred_weak_aug_before_nms[unlabeled_inds, ...] = pred_weak_aug_unlab_before_nms_aligned
pred_wa_ulb = pred_wa[unlabeled_inds, ...]
pred_wa_ulb_aligned = pred_wa_ulb * thresh_reg.ema_pred_wa_lab / (thresh_reg.ema_pred_wa_ulb + 1e-6)
pred_wa_ulb_aligned = thresh_reg.normalize_(pred_wa_ulb_aligned)
pred_wa[unlabeled_inds, ...] = pred_wa_ulb_aligned

batch_dict_ema['batch_cls_preds_org'] = pred_weak_aug_before_nms_org
batch_dict_ema['batch_cls_preds'] = pred_weak_aug_before_nms
batch_dict_ema['batch_cls_preds'] = pred_wa
batch_dict_ema['cls_preds_normalized'] = True

def _gen_pseudo_labels(self, batch_dict_ema, ulb_inds):
Expand Down Expand Up @@ -266,9 +264,8 @@ def _forward_training(self, batch_dict):

# update dynamic thresh results
for tag in self.thresh_registry.tags():
results = self.thresh_registry.get(tag).compute()
if results:
tag = tag + "/" if tag else ''
if results := self.thresh_registry.get(tag).compute():
tag = f"{tag}/" if tag else ''
tb_dict_.update({tag + key: val for key, val in results.items()})

for tag in metrics_registry.tags():
Expand Down
Loading

0 comments on commit 87bfa2c

Please sign in to comment.