Skip to content
This repository has been archived by the owner on Oct 31, 2023. It is now read-only.

on-line hard example mining-resubmit #845

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions maskrcnn_benchmark/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@
_C.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE = 512
# Target fraction of RoI minibatch that is labeled foreground (i.e. class > 0)
_C.MODEL.ROI_HEADS.POSITIVE_FRACTION = 0.25
# Whether to use online hard example mining (OHEM) when selecting RoIs for the box head
_C.MODEL.ROI_HEADS.OHEM = False

# Only used on test mode

Expand Down
12 changes: 8 additions & 4 deletions maskrcnn_benchmark/layers/smooth_l1_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@


# TODO maybe push this to nn?
def smooth_l1_loss(input, target, beta=1. / 9, reduction='mean'):
    """
    Very similar to the smooth_l1_loss from pytorch, but with
    the extra beta parameter.

    Arguments:
        input (Tensor): predicted values.
        target (Tensor): reference values; the loss is computed on
            ``|input - target|``.
        beta (float): threshold on the absolute error below which the
            quadratic branch ``0.5 * n ** 2 / beta`` is used; at or above
            it the linear branch ``n - 0.5 * beta`` is used.
        reduction (str): ``'mean'`` averages over all elements, ``'sum'``
            sums over all elements, ``'none'`` returns the element-wise
            loss unchanged.

    Returns:
        Tensor: a scalar tensor for ``'mean'`` / ``'sum'``, otherwise a
        tensor with the element-wise loss.

    Raises:
        ValueError: if ``reduction`` is not one of the supported modes.
    """
    # Fail loudly on a misspelled mode instead of silently returning the
    # unreduced tensor.
    if reduction not in ('mean', 'sum', 'none'):
        raise ValueError(
            "reduction must be 'mean', 'sum' or 'none', got {!r}".format(reduction)
        )

    n = torch.abs(input - target)
    cond = n < beta
    loss = torch.where(cond, 0.5 * n ** 2 / beta, n - 0.5 * beta)

    if reduction == 'mean':
        return loss.mean()
    if reduction == 'sum':
        return loss.sum()
    return loss
50 changes: 50 additions & 0 deletions maskrcnn_benchmark/modeling/balanced_positive_negative_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,3 +66,53 @@ def __call__(self, matched_idxs):
neg_idx.append(neg_idx_per_image_mask)

return pos_idx, neg_idx


class OhemPositiveNegativeSampler(object):
    """
    Sampler that performs no subsampling: every positive and every negative
    element of each image is kept, and only the ignored (-1) elements are
    dropped. It exposes the same call interface as a positive/negative
    sampler, so the choice of which elements to train on can be left to a
    later loss-based mining step.
    """

    def __init__(self):
        """
        Takes no arguments: there is no batch size or positive fraction to
        configure because nothing is subsampled.
        """

    def __call__(self, matched_idxs):
        """
        Arguments:
            matched idxs: list of tensors containing -1, 0 or positive values.
                Each tensor corresponds to a specific image.
                -1 values are ignored, 0 are considered as negatives and > 0 as
                positives.

        Returns:
            pos_idx (list[tensor])
            neg_idx (list[tensor])

        Returns two lists of binary (uint8) masks, one mask per image.
        The first list marks all the positive elements, and the second
        list marks all the negative elements.
        """
        pos_idx = []
        neg_idx = []
        for matched_idxs_per_image in matched_idxs:
            # Build the masks directly from the comparisons; this avoids the
            # nonzero() + index-assignment round trip while producing the same
            # uint8 masks.
            pos_idx.append((matched_idxs_per_image >= 1).to(torch.uint8))
            neg_idx.append((matched_idxs_per_image == 0).to(torch.uint8))

        return pos_idx, neg_idx
7 changes: 7 additions & 0 deletions maskrcnn_benchmark/modeling/roi_heads/box_head/box_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ def __init__(self, cfg, in_channels):
cfg, self.feature_extractor.out_channels)
self.post_processor = make_roi_box_post_processor(cfg)
self.loss_evaluator = make_roi_box_loss_evaluator(cfg)
self.ohem = cfg.MODEL.ROI_HEADS.OHEM

def forward(self, features, proposals, targets=None):
"""
Expand All @@ -41,6 +42,12 @@ def forward(self, features, proposals, targets=None):
# positive / negative ratio
with torch.no_grad():
proposals = self.loss_evaluator.subsample(proposals, targets)
if self.ohem:
x = self.feature_extractor(features, proposals)
class_logits, box_regression = self.predictor(x)
proposals = self.loss_evaluator.mining(
[class_logits], [box_regression]
)

# extract features that will be fed to the final classifier. The
# feature_extractor generally corresponds to the pooler + heads
Expand Down
83 changes: 77 additions & 6 deletions maskrcnn_benchmark/modeling/roi_heads/box_head/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from maskrcnn_benchmark.modeling.matcher import Matcher
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.modeling.balanced_positive_negative_sampler import (
BalancedPositiveNegativeSampler
BalancedPositiveNegativeSampler,
OhemPositiveNegativeSampler
)
from maskrcnn_benchmark.modeling.utils import cat

Expand All @@ -23,17 +24,19 @@ def __init__(
proposal_matcher,
fg_bg_sampler,
box_coder,
batch_size_per_image,
cls_agnostic_bbox_reg=False
):
"""
Arguments:
proposal_matcher (Matcher)
fg_bg_sampler (BalancedPositiveNegativeSampler)
fg_bg_sampler (BalancedPositiveNegativeSampler, or OhemPositiveNegativeSampler)
box_coder (BoxCoder)
"""
self.proposal_matcher = proposal_matcher
self.fg_bg_sampler = fg_bg_sampler
self.box_coder = box_coder
self.batch_size_per_image = batch_size_per_image
self.cls_agnostic_bbox_reg = cls_agnostic_bbox_reg

def match_targets_to_proposals(self, proposal, target):
Expand Down Expand Up @@ -105,16 +108,79 @@ def subsample(self, proposals, targets):

# distributed sampled proposals, that were obtained on all feature maps
# concatenated via the fg_bg_sampler, into individual feature map levels
self.n_proposals_per_img = []
for img_idx, (pos_inds_img, neg_inds_img) in enumerate(
zip(sampled_pos_inds, sampled_neg_inds)
):
img_sampled_inds = torch.nonzero(pos_inds_img | neg_inds_img).squeeze(1)
proposals_per_image = proposals[img_idx][img_sampled_inds]
self.n_proposals_per_img.append(len(proposals_per_image))
proposals[img_idx] = proposals_per_image

self._proposals = proposals
return proposals

def mining(self, class_logits, box_regression):
    """
    Similar role as subsample(), but returns the RoIs with the highest loss.

    The per-RoI loss is the classification cross-entropy, plus (for RoIs
    whose label is > 0) the smooth-L1 box regression loss summed over the
    four box coordinates. When there are more RoIs than
    ``self.batch_size_per_image``, only the RoIs with the largest loss are
    kept and ``self._proposals`` is replaced by that selection.

    Arguments:
        class_logits (list[Tensor])
        box_regression (list[Tensor])

    Returns:
        proposals (list[BoxList])

    Raises:
        RuntimeError: if subsample() has not been called beforehand.
    """
    # Check the required state first, before doing any tensor work.
    if not hasattr(self, "_proposals"):
        raise RuntimeError("subsample needs to be called before")

    proposals = self._proposals

    class_logits = cat(class_logits, dim=0)
    box_regression = cat(box_regression, dim=0)
    device = class_logits.device

    labels = cat([proposal.get_field("labels") for proposal in proposals], dim=0)
    regression_targets = cat(
        [proposal.get_field("regression_targets") for proposal in proposals], dim=0
    )

    classification_loss = F.cross_entropy(class_logits, labels, reduction='none')

    # get indices that correspond to the regression targets for
    # the corresponding ground truth labels, to be used with
    # advanced indexing
    sampled_pos_inds_subset = torch.nonzero(labels > 0).squeeze(1)
    labels_pos = labels[sampled_pos_inds_subset]
    if self.cls_agnostic_bbox_reg:
        map_inds = torch.tensor([4, 5, 6, 7], device=device)
    else:
        map_inds = 4 * labels_pos[:, None] + torch.tensor(
            [0, 1, 2, 3], device=device)

    # one regression loss value per positive RoI (summed over coordinates)
    box_loss = smooth_l1_loss(
        box_regression[sampled_pos_inds_subset[:, None], map_inds],
        regression_targets[sampled_pos_inds_subset],
        reduction='none',
        beta=1,
    ).sum(dim=1)

    # clone so the classification loss itself is not modified
    ohem_loss = classification_loss.clone()
    ohem_loss[sampled_pos_inds_subset] = (
        ohem_loss[sampled_pos_inds_subset] + box_loss
    )

    # NOTE(review): the budget is compared against, and drawn from, the RoIs
    # of all images concatenated together, although the attribute is named
    # "per image" -- confirm whether it should scale with the number of images.
    if ohem_loss.size(0) > self.batch_size_per_image:
        # sorted so that the RoIs kept for each image stay in their original order
        ohem_idx = ohem_loss.topk(self.batch_size_per_image)[1].sort()[0]

        # Map the indices into the concatenated tensor back to per-image
        # indices, using the per-image RoI counts recorded by subsample().
        kept_proposals = []
        start = 0
        for proposals_per_image, n_rois in zip(proposals, self.n_proposals_per_img):
            end = start + n_rois
            in_image = (ohem_idx >= start) & (ohem_idx < end)
            kept_proposals.append(proposals_per_image[ohem_idx[in_image] - start])
            start = end
        self._proposals = kept_proposals

    return self._proposals

def __call__(self, class_logits, box_regression):
"""
Computes the loss for Faster R-CNN.
Expand Down Expand Up @@ -159,7 +225,7 @@ def __call__(self, class_logits, box_regression):
box_loss = smooth_l1_loss(
box_regression[sampled_pos_inds_subset[:, None], map_inds],
regression_targets[sampled_pos_inds_subset],
size_average=False,
reduction='sum',
beta=1,
)
box_loss = box_loss / labels.numel()
Expand All @@ -177,16 +243,21 @@ def make_roi_box_loss_evaluator(cfg):
bbox_reg_weights = cfg.MODEL.ROI_HEADS.BBOX_REG_WEIGHTS
box_coder = BoxCoder(weights=bbox_reg_weights)

fg_bg_sampler = BalancedPositiveNegativeSampler(
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE, cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION
)
if cfg.MODEL.ROI_HEADS.OHEM:
fg_bg_sampler = OhemPositiveNegativeSampler()
else:
fg_bg_sampler = BalancedPositiveNegativeSampler(
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE,
cfg.MODEL.ROI_HEADS.POSITIVE_FRACTION
)

cls_agnostic_bbox_reg = cfg.MODEL.CLS_AGNOSTIC_BBOX_REG

loss_evaluator = FastRCNNLossComputation(
matcher,
fg_bg_sampler,
box_coder,
cfg.MODEL.ROI_HEADS.BATCH_SIZE_PER_IMAGE,
cls_agnostic_bbox_reg
)

Expand Down
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/modeling/rpn/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def __call__(self, anchors, objectness, box_regression, targets):
box_regression[sampled_pos_inds],
regression_targets[sampled_pos_inds],
beta=1.0 / 9,
size_average=False,
reduction='sum',
) / (sampled_inds.numel())

objectness_loss = F.binary_cross_entropy_with_logits(
Expand Down
2 changes: 1 addition & 1 deletion maskrcnn_benchmark/modeling/rpn/retinanet/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def __call__(self, anchors, box_cls, box_regression, targets):
box_regression[pos_inds],
regression_targets[pos_inds],
beta=self.bbox_reg_beta,
size_average=False,
reduction='sum',
) / (max(1, pos_inds.numel() * self.regress_norm))

labels = labels.int()
Expand Down