Hello! Thank you for the great work.
I have two questions about loss.py - ImgWtLossSoftNLL.
1.
As far as I understand, lines 182 to 188 compute the loss for each image separately, looping once per sample in the batch.
However, when `custom_nll` is called, `weights` is not indexed per sample: the whole batch tensor is passed as `border_weights`.
I think it should pass `weights[i]` instead. Could you check whether I'm right?
```python
# As-is
for i in range(0, inputs.shape[0]):
    if not self.batch_weights:
        class_weights = self.calculate_weights(target_cpu[i])
    loss = loss + self.custom_nll(inputs[i].unsqueeze(0),
                                  target[i].unsqueeze(0),
                                  class_weights=torch.Tensor(class_weights).cuda(),
                                  # the whole batch's border weights are passed here
                                  border_weights=weights, mask=ignore_mask[i])
```
```python
# To-be
for i in range(0, inputs.shape[0]):
    if not self.batch_weights:
        class_weights = self.calculate_weights(target_cpu[i])
    loss = loss + self.custom_nll(inputs[i].unsqueeze(0),
                                  target[i].unsqueeze(0),
                                  class_weights=torch.Tensor(class_weights).cuda(),
                                  # only the i-th image's border weights
                                  border_weights=weights[i], mask=ignore_mask[i])
```
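To illustrate why the indexing matters, here is a minimal NumPy sketch of the shape problem. It assumes (this is not confirmed above) that `custom_nll` divides a per-pixel term of shape (1, H, W) by `border_weights` and then sums the result; the helper `weighted_sum` and all sizes are hypothetical. PyTorch tensors follow the same broadcasting rules as NumPy arrays.

```python
import numpy as np

# Hypothetical sizes: N images of H x W pixels.
N, H, W = 4, 2, 3
rng = np.random.default_rng(0)

# Assumption: the per-pixel loss term for ONE image keeps a leading batch
# axis of size 1, as after inputs[i].unsqueeze(0).
per_pixel = rng.random((1, H, W))

# Border weights for the WHOLE batch, one (H, W) map per image (all > 0).
border_weights = 1.0 + rng.random((N, H, W))


def weighted_sum(term, weights):
    """Hypothetical stand-in for the weighting step: divide by weights, then sum."""
    weighted = term / weights
    return weighted.sum(), weighted.shape


# As-is: (1, H, W) broadcasts against (N, H, W), so image i's term is
# counted N times, once for every image's weight map.
loss_all, shape_all = weighted_sum(per_pixel, border_weights)

# To-be: weights[i] has shape (H, W) and lines up with image i pixel for pixel.
i = 0
loss_i, shape_i = weighted_sum(per_pixel, border_weights[i])

print(shape_all)  # (4, 2, 3)
print(shape_i)    # (1, 2, 3)
```

Under this assumption, passing the whole batch silently produces an (N, H, W) result instead of raising a shape error, which is why the bug is easy to miss.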
2.
At line 192 of loss.py, I think `loss` should be divided by the batch size (i.e. `inputs.shape[0]`).
For comparison, CrossEntropyLoss2d returns the loss averaged over the batch.
```python
# As-is
return loss
```
```python
# To-be: average over the batch instead of summing
return loss / inputs.shape[0]
```
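The effect of the missing division can be seen in a minimal pure-Python sketch; `looped_loss` is a hypothetical stand-in for the per-image loop, not code from loss.py. Summing per-image losses makes the returned value (and therefore its gradients) grow linearly with the batch size, while dividing by the batch size keeps it on the same scale regardless of how many images are in the batch.

```python
def looped_loss(per_sample_losses, average=True):
    """Hypothetical stand-in for the per-image loop in ImgWtLossSoftNLL."""
    loss = 0.0
    for sample_loss in per_sample_losses:
        loss = loss + sample_loss  # accumulate, as in the for-loop above
    if average:
        return loss / len(per_sample_losses)  # proposed: divide by batch size
    return loss  # current behaviour: plain sum


# Same per-image loss, two different batch sizes.
small = [0.5] * 2
large = [0.5] * 8

print(looped_loss(small, average=False), looped_loss(large, average=False))  # 1.0 4.0
print(looped_loss(small), looped_loss(large))  # 0.5 0.5
```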
Thank you in advance. :)