From 89f3c48572c08d7722237fb7f09f0bd5e58fc966 Mon Sep 17 00:00:00 2001
From: hyq
Date: Wed, 23 Oct 2019 09:52:09 +0000
Subject: [PATCH] fix FP16 training for retinanet

---
 maskrcnn_benchmark/layers/sigmoid_focal_loss.py | 18 ++++++++++++++----
 .../modeling/rpn/retinanet/loss.py              |  3 ++-
 2 files changed, 16 insertions(+), 5 deletions(-)

diff --git a/maskrcnn_benchmark/layers/sigmoid_focal_loss.py b/maskrcnn_benchmark/layers/sigmoid_focal_loss.py
index 22b8018df..54263f910 100644
--- a/maskrcnn_benchmark/layers/sigmoid_focal_loss.py
+++ b/maskrcnn_benchmark/layers/sigmoid_focal_loss.py
@@ -8,12 +8,16 @@
 # TODO: Use JIT to replace CUDA implementation in the future.
 class _SigmoidFocalLoss(Function):
     @staticmethod
-    def forward(ctx, logits, targets, gamma, alpha):
+    def forward(ctx, logits, targets, gamma, alpha, dtype):
+        if dtype == 'float16':
+            logits = logits.float()
         ctx.save_for_backward(logits, targets)
         num_classes = logits.shape[1]
         ctx.num_classes = num_classes
         ctx.gamma = gamma
         ctx.alpha = alpha
+        ctx.dtype = dtype
+
         losses = _C.sigmoid_focalloss_forward(
             logits, targets, num_classes, gamma, alpha
         )
@@ -31,6 +35,9 @@ def backward(ctx, d_loss):
         d_logits = _C.sigmoid_focalloss_backward(
             logits, targets, d_loss, num_classes, gamma, alpha
         )
+        if ctx.dtype == 'float16':
+            d_logits = d_logits.half()
+
         return d_logits, None, None, None, None
 
 
@@ -39,6 +46,8 @@ def backward(ctx, d_loss):
 
-def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
+def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha, dtype=None):  # dtype unused on CPU; accepted for call-site parity with the CUDA path
     num_classes = logits.shape[1]
+    gamma = gamma[0]
+    alpha = alpha[0]
     dtype = targets.dtype
     device = targets.device
     class_range = torch.arange(1, num_classes+1, dtype=dtype, device=device).unsqueeze(0)
@@ -51,10 +60,11 @@ def sigmoid_focal_loss_cpu(logits, targets, gamma, alpha):
 
 
 class SigmoidFocalLoss(nn.Module):
-    def __init__(self, gamma, alpha):
+    def __init__(self, gamma, alpha, dtype):
         super(SigmoidFocalLoss, self).__init__()
         self.gamma = gamma
         self.alpha = alpha
+        self.dtype = dtype
 
     def forward(self, logits, targets):
         device = logits.device
@@ -63,7 +73,7 @@ def forward(self, logits, targets):
         else:
             loss_func = sigmoid_focal_loss_cpu
 
-        loss = loss_func(logits, targets, self.gamma, self.alpha)
+        loss = loss_func(logits, targets, self.gamma, self.alpha, self.dtype)
         return loss.sum()
 
     def __repr__(self):
diff --git a/maskrcnn_benchmark/modeling/rpn/retinanet/loss.py b/maskrcnn_benchmark/modeling/rpn/retinanet/loss.py
index 080e2153b..e58d7ab13 100644
--- a/maskrcnn_benchmark/modeling/rpn/retinanet/loss.py
+++ b/maskrcnn_benchmark/modeling/rpn/retinanet/loss.py
@@ -93,7 +93,8 @@ def make_retinanet_loss_evaluator(cfg, box_coder):
     )
     sigmoid_focal_loss = SigmoidFocalLoss(
         cfg.MODEL.RETINANET.LOSS_GAMMA,
-        cfg.MODEL.RETINANET.LOSS_ALPHA
+        cfg.MODEL.RETINANET.LOSS_ALPHA,
+        cfg.DTYPE
     )

     loss_evaluator = RetinaNetLossComputation(
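
Note on the technique (a sketch, not part of the patch): the fix follows the
standard mixed-precision recipe for a custom autograd Function whose CUDA
kernel only supports fp32 -- upcast float16 logits to float32 before the
forward kernel, remember the precision, and cast the gradient back to float16
in backward so it matches the rest of the half-precision network. The
self-contained PyTorch sketch below illustrates the same pattern;
FP32ComputeLoss is a hypothetical name, and softplus stands in for the
_C.sigmoid_focalloss_* kernels.

import torch
from torch.autograd import Function


class FP32ComputeLoss(Function):
    """Illustrative only: run a numerically sensitive op in fp32 inside an
    fp16 network, handing fp16 gradients back to the caller."""

    @staticmethod
    def forward(ctx, logits):
        ctx.orig_dtype = logits.dtype
        if ctx.orig_dtype == torch.float16:
            logits = logits.float()  # upcast, as the patch does when dtype == 'float16'
        ctx.save_for_backward(logits)
        # softplus as a stand-in for _C.sigmoid_focalloss_forward
        return torch.log1p(logits.exp())

    @staticmethod
    def backward(ctx, d_loss):
        (logits,) = ctx.saved_tensors
        d_logits = d_loss * torch.sigmoid(logits)  # analytic softplus gradient
        if ctx.orig_dtype == torch.float16:
            d_logits = d_logits.half()  # downcast, as the patch does in backward
        return d_logits


x = torch.randn(4, 8).half().requires_grad_()
loss = FP32ComputeLoss.apply(x).sum()  # loss stays in fp32, like the patch's losses
loss.backward()                        # x.grad comes back as float16

One simplification: the sketch infers the original precision from
logits.dtype, while the patch threads it in explicitly as the string
cfg.DTYPE ('float16' when mixed-precision training is enabled in
maskrcnn-benchmark).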