"""
Utility methods for constructing loss functions
"""
from typing import List, Optional, Tuple
from sklearn.utils.class_weight import compute_class_weight
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
# TODO: replace this constant with an argument from the config
DECORRELATE = False


def create_loss(config):
"""
Parse configuration file and return a relevant loss function
"""
if config['model'] == 'MVCBM' or config['model'] == 'CBM':
if config['training_mode'] == 'sequential':
return MVCBLoss(num_classes=config['num_classes'])
elif config['training_mode'] == 'joint':
return MVCBLoss(num_classes=config['num_classes'], alpha=config['alpha'])
elif config["model"] in ["SSMVCBM"]:
return SSMVCBLoss(num_classes=config['num_classes'])
else:
return nn.BCELoss()
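
# A minimal usage sketch for create_loss (hypothetical config values; only the
# keys read above are assumed to exist):
#   config = {'model': 'MVCBM', 'training_mode': 'joint', 'num_classes': 2, 'alpha': 0.5}
#   criterion = create_loss(config)  # -> MVCBLoss(num_classes=2, alpha=0.5)
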
class MVCBLoss(nn.Module):
"""
Loss function for the (multiview) concept bottleneck model
"""
# NOTE: this loss function is also applicable to the vanilla CBMs
def __init__(
self,
            num_classes: int = 2,
target_class_weight: Optional[Tensor] = None,
target_sample_weight: Optional[Tensor] = None,
c_weights: Optional[Tensor] = None,
reduction: str = "mean",
alpha: float = 1) -> None:
"""
Initializes the loss object
@param num_classes: the number of the classes of the target variable
@param target_class_weight: weight per target's class
@param target_sample_weight: target class weights per data point
@param c_weights: concepts weights
@param reduction: reduction to apply to the output of the CE loss
        @param alpha: parameter controlling the trade-off between the target and concept prediction losses during
        joint optimization. The higher the @alpha, the higher the weight of the concept prediction loss
"""
        super().__init__()
self.num_classes = num_classes
self.target_class_weight = target_class_weight
# NOTE: these weights will need to be updated every time before the loss is computed
self.target_sample_weight = target_sample_weight
self.c_weights = c_weights
self.reduction = reduction
self.alpha = alpha

    def forward(self, concepts_pred: Tensor, concepts_true: Tensor, target_pred_probs: Tensor,
                target_pred_logits: Tensor, target_true: Tensor) -> Tuple[Tensor, List[Tensor], Tensor, Tensor]:
"""
Computes the loss for the given predictions
@param concepts_pred: predicted concept values
@param concepts_true: ground-truth concept values
        @param target_pred_probs: predicted probabilities for the target variable, i.e. logits after a sigmoid/softmax
@param target_pred_logits: predicted logits for the target variable
@param target_true: ground-truth target variable values
        @return: target prediction loss, a list of per-concept prediction losses, summed concept
        prediction loss and the total loss
"""
summed_concepts_loss = 0
concepts_loss = []
# NOTE: all concepts are assumed to be binary-valued
# TODO: introduce continuously- and categorically-valued concepts
for concept_idx in range(concepts_true.shape[1]):
            # Combine the per-sample weights with this concept's class weight, if both are provided
            w = self.target_sample_weight * self.c_weights[concept_idx] \
                if self.target_sample_weight is not None and self.c_weights is not None else None
c_loss = F.binary_cross_entropy(
concepts_pred[:, concept_idx], concepts_true[:, concept_idx].float(), weight=w, reduction=self.reduction)
concepts_loss.append(c_loss)
summed_concepts_loss += c_loss

        # Target prediction loss term
        if self.num_classes == 2:
target_loss = F.binary_cross_entropy(
target_pred_probs, target_true, weight=self.target_sample_weight, reduction=self.reduction)
else:
target_loss = F.cross_entropy(
target_pred_logits, target_true.long(), weight=self.target_class_weight, reduction=self.reduction)
total_loss = target_loss + self.alpha * summed_concepts_loss
return target_loss, concepts_loss, summed_concepts_loss, total_loss
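
# An illustrative sketch (not part of the module): evaluating the joint loss on a
# random batch of 4 samples with 3 binary concepts and a binary target. All shapes
# are assumptions inferred from the indexing in forward().
#   criterion = MVCBLoss(num_classes=2, alpha=0.5)
#   concepts_pred = torch.rand(4, 3)            # sigmoid outputs in [0, 1]
#   concepts_true = torch.randint(0, 2, (4, 3))
#   target_probs = torch.rand(4)                # predicted P(y = 1) per sample
#   target_logits = torch.logit(target_probs)   # unused when num_classes == 2
#   target_true = torch.randint(0, 2, (4,)).float()
#   target_loss, concepts_loss, summed_c_loss, total_loss = criterion(
#       concepts_pred, concepts_true, target_probs, target_logits, target_true)
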
class SSMVCBLoss(nn.Module):
"""
Loss function for the semi-supervised multiview concept bottleneck model
"""
def __init__(
self,
            num_classes: int = 2,
target_class_weight: Optional[Tensor] = None,
target_sample_weight: Optional[Tensor] = None,
c_weights: Optional[Tensor] = None,
reduction: str = "mean"
) -> None:
"""
Initializes the loss object
@param num_classes: the number of the classes of the target variable
@param target_class_weight: weight per target's class
@param target_sample_weight: target class weights per data point
@param c_weights: concepts weights
@param reduction: reduction to apply to the output of the CE loss
"""
        super().__init__()
self.num_classes = num_classes
self.target_class_weight = target_class_weight
# NOTE: these weights will need to be updated every time before the loss is computed
self.target_sample_weight = target_sample_weight
self.c_weights = c_weights
self.reduction = reduction

    def forward(self, s_concepts_pred: Tensor, discr_concepts_pred: Tensor, concepts_true: Tensor,
                target_pred_probs: Tensor, target_pred_logits: Tensor, target_true: Tensor,
                us_concepts_sample: Tensor) -> Tuple[Tensor, List[Tensor], Tensor, Tensor, Tensor, Tensor]:
"""
Computes the loss for the given predictions
@param s_concepts_pred: predicted concept values
@param discr_concepts_pred: concept predictions made by the adversary
@param concepts_true: ground-truth concept values
        @param target_pred_probs: predicted probabilities for the target variable, i.e. logits after a sigmoid/softmax
@param target_pred_logits: predicted logits for the target variable
@param target_true: ground-truth target variable values
@param us_concepts_sample: a sample of the unsupervised representations
        @return: target prediction loss, a list of per-concept prediction losses, summed concept prediction loss,
        summed concept prediction loss for the adversary (its positive and negated values) and
        the representation de-correlation loss
"""
summed_discr_concepts_loss = 0
summed_s_concepts_loss = 0
s_concepts_loss = []
# Supervised concepts loss
# NOTE: all concepts are assumed to be binary-valued
# TODO: introduce continuously- and categorically-valued concepts
for concept_idx in range(concepts_true.shape[1]):
            w = self.target_sample_weight * self.c_weights[concept_idx] \
                if self.target_sample_weight is not None and self.c_weights is not None else None
c_loss = F.binary_cross_entropy(
s_concepts_pred[:, concept_idx], concepts_true[:, concept_idx].float(), weight=w, reduction=self.reduction)
s_concepts_loss.append(c_loss)
summed_s_concepts_loss += c_loss
# Adversarial loss term
for concept_idx in range(concepts_true.shape[1]):
w = self.c_weights[concept_idx] if self.c_weights is not None else None
c_loss = F.binary_cross_entropy(
discr_concepts_pred[:, concept_idx], s_concepts_pred[:, concept_idx], weight=w, reduction=self.reduction)
summed_discr_concepts_loss += c_loss
        # Adversarial term for the unsupervised representation: the negated discriminator loss
        summed_gen_concepts_loss = -summed_discr_concepts_loss
# Compute covariance among the dimensions of the unsupervised representation
# NOTE: can cause issues during the optimisation
# NOTE: this loss term is disabled in the current implementation
        if DECORRELATE:
            cov = torch.cov(us_concepts_sample.T)
        else:
            # NOTE: create the zero tensor on the same device as the representation to avoid device mismatches
            cov = torch.zeros((us_concepts_sample.shape[1], us_concepts_sample.shape[1]),
                              device=us_concepts_sample.device)
        # Penalise only the off-diagonal entries, i.e. the pairwise covariances
        cov = cov.fill_diagonal_(0)
        us_corr_loss = torch.square(torch.linalg.matrix_norm(cov))
# Target prediction loss term
if self.num_classes == 2:
target_loss = F.binary_cross_entropy(
target_pred_probs, target_true, weight=self.target_sample_weight, reduction=self.reduction)
else:
target_loss = F.cross_entropy(
target_pred_logits, target_true.long(), weight=self.target_class_weight, reduction=self.reduction)
return target_loss, s_concepts_loss, summed_s_concepts_loss, summed_discr_concepts_loss, \
summed_gen_concepts_loss, us_corr_loss
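
# An illustrative sketch (made-up shapes): 4 samples, 3 supervised concepts, an
# 8-dimensional unsupervised representation and a binary target.
#   criterion = SSMVCBLoss(num_classes=2)
#   s_pred = torch.rand(4, 3)                   # supervised concept predictions
#   d_pred = torch.rand(4, 3)                   # adversary's concept predictions
#   c_true = torch.randint(0, 2, (4, 3))
#   y_probs = torch.rand(4)
#   y_logits = torch.logit(y_probs)             # unused when num_classes == 2
#   y_true = torch.randint(0, 2, (4,)).float()
#   us_sample = torch.randn(4, 8)               # sampled unsupervised representation
#   (target_loss, s_losses, summed_s, summed_discr,
#    summed_gen, us_corr) = criterion(s_pred, d_pred, c_true, y_probs, y_logits,
#                                     y_true, us_sample)
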
def calc_concept_weights(all_c):
"""
Computes class weights for every list element all_c[i] corresponding to a set of the i-th concept values
"""
concepts_class_weights = []
for concept_idx in range(len(all_c)):
c_class_weights = compute_class_weight(class_weight="balanced", classes=[0, 1], y=all_c[concept_idx])
concepts_class_weights.append(c_class_weights)
return concepts_class_weights
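
# Example with hypothetical data: two binary concepts observed over five samples.
# Balanced weights are n_samples / (n_classes * class_count).
#   all_c = [[0, 0, 1, 1, 1], [0, 1, 1, 1, 1]]
#   weights = calc_concept_weights(all_c)
#   # weights[0] ~ [1.25, 0.83]; weights[1] ~ [2.5, 0.625]
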
def calc_concept_sample_weights(config, concepts_class_weights, batch_concepts):
"""
Assigns precomputed concept class weights to every sample in a batch.
"""
concepts_sample_weights = []
    for concept_idx in range(len(concepts_class_weights)):
        c_sample_weights = [
            concepts_class_weights[concept_idx][0] if int(batch_concepts[i][concept_idx]) == 0
            else concepts_class_weights[concept_idx][1]
            for i in range(len(batch_concepts))]
        concepts_sample_weights.append(torch.FloatTensor(c_sample_weights).to(config["device"]))
return concepts_sample_weights
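
# Example with a hypothetical batch: turning the per-concept class weights into
# per-sample weight tensors of the kind consumed by the losses above.
#   config = {'device': 'cpu'}
#   class_weights = calc_concept_weights([[0, 0, 1, 1], [0, 1, 1, 1]])
#   batch_concepts = torch.tensor([[0, 1], [1, 1], [1, 0]])
#   sample_weights = calc_concept_sample_weights(config, class_weights, batch_concepts)
#   # sample_weights[k][i] is the weight of sample i for concept k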