losses.py (forked from ChunML/ssd-tf2)
import tensorflow as tf


def hard_negative_mining(loss, gt_confs, neg_ratio):
    """ Hard negative mining algorithm
        to pick negative examples for back-propagation
        based on classification loss values

    Args:
        loss: classification losses of all default boxes (B, num_default)
        gt_confs: classification targets (B, num_default)
        neg_ratio: negative / positive ratio

    Returns:
        pos_idx: boolean mask of positive default boxes (B, num_default)
        neg_idx: boolean mask of selected hard negative boxes (B, num_default)
    """
    # loss: B x N
    # gt_confs: B x N
    pos_idx = gt_confs > 0
    num_pos = tf.reduce_sum(tf.dtypes.cast(pos_idx, tf.int32), axis=1)
    num_neg = num_pos * neg_ratio

    # Double argsort: the first sorts boxes by descending loss, the second
    # turns that ordering into each box's rank. A box is kept as a hard
    # negative if its rank is below num_neg for its batch row.
    rank = tf.argsort(loss, axis=1, direction='DESCENDING')
    rank = tf.argsort(rank, axis=1)
    neg_idx = rank < tf.expand_dims(num_neg, 1)

    return pos_idx, neg_idx
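
# Worked example (illustrative only, not part of the original file). With
#   loss      = [[0.9, 0.1, 0.5, 0.7]]     # (B=1, num_default=4)
#   gt_confs  = [[3, 0, 0, 0]]             # one positive box
#   neg_ratio = 3
# the double argsort yields rank = [[0, 3, 2, 1]], so
#   pos_idx = [[True, False, False, False]]
#   neg_idx = [[True, False, True, True]]  # the 3 highest-loss boxes
# Overlap between the two masks is harmless: the caller combines them with
# tf.math.logical_or before masking, so no box is counted twice.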


class SSDLosses(object):
    """ Class for SSD losses

    Attributes:
        neg_ratio: negative / positive ratio
        num_classes: number of classes
    """

    def __init__(self, neg_ratio, num_classes):
        self.neg_ratio = neg_ratio
        self.num_classes = num_classes
    def __call__(self, confs, locs, gt_confs, gt_locs):
        """ Compute losses for SSD
            regression loss: smooth L1 (Huber)
            classification loss: cross entropy

        Args:
            confs: outputs of classification heads (B, num_default, num_classes)
            locs: outputs of regression heads (B, num_default, 4)
            gt_confs: classification targets (B, num_default)
            gt_locs: regression targets (B, num_default, 4)

        Returns:
            conf_loss: classification loss
            loc_loss: regression loss
        """
        # compute per-box classification losses without reduction,
        # so hard negative mining can rank individual boxes
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='none')
        temp_loss = cross_entropy(gt_confs, confs)

        pos_idx, neg_idx = hard_negative_mining(
            temp_loss, gt_confs, self.neg_ratio)

        # the classification loss consists of positive and hard negative examples
        cross_entropy = tf.keras.losses.SparseCategoricalCrossentropy(
            from_logits=True, reduction='sum')
        smooth_l1_loss = tf.keras.losses.Huber(reduction='sum')

        conf_loss = cross_entropy(
            gt_confs[tf.math.logical_or(pos_idx, neg_idx)],
            confs[tf.math.logical_or(pos_idx, neg_idx)])

        # the regression loss consists of positive examples only;
        # boolean indexing is equivalent to tf.boolean_mask here
        loc_loss = smooth_l1_loss(
            gt_locs[pos_idx],
            locs[pos_idx])

        # normalize both losses by the number of positive boxes
        num_pos = tf.reduce_sum(tf.dtypes.cast(pos_idx, tf.float32))
        conf_loss = conf_loss / num_pos
        loc_loss = loc_loss / num_pos

        return conf_loss, loc_loss


def create_losses(neg_ratio, num_classes):
    criterion = SSDLosses(neg_ratio, num_classes)

    return criterion
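

if __name__ == '__main__':
    # Minimal usage sketch (added for illustration; not in the original file).
    # The batch size, number of default boxes, and class count below are
    # arbitrary placeholders, with class 0 treated as background.
    batch, num_default, num_classes = 2, 100, 21
    criterion = create_losses(neg_ratio=3, num_classes=num_classes)

    confs = tf.random.normal([batch, num_default, num_classes])  # class logits
    locs = tf.random.normal([batch, num_default, 4])             # box offsets
    gt_confs = tf.random.uniform(
        [batch, num_default], 0, num_classes, dtype=tf.int32)    # 0 = background
    gt_locs = tf.random.normal([batch, num_default, 4])

    conf_loss, loc_loss = criterion(confs, locs, gt_confs, gt_locs)
    print('conf_loss:', conf_loss.numpy(), 'loc_loss:', loc_loss.numpy())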