Skip to content

Commit

Permalink
fix calculation of weighted loss
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-gecheng committed Dec 16, 2024
1 parent 5097d1d commit d7e5380
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion tzrec/models/match_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,9 @@ def _loss_impl(
label = _zero_int_label(pred)
losses[loss_name] = self._loss_modules[loss_name](pred, label)
if self._sample_weight:
losses[loss_name] = torch.mean(losses[loss_name] * sample_weight)
losses[loss_name] = torch.sum(
losses[loss_name] * sample_weight
) / torch.sum(sample_weight)

return losses

Expand Down
4 changes: 3 additions & 1 deletion tzrec/models/rank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,9 @@ def _loss_impl(
raise ValueError(f"loss[{loss_type}] is not supported yet.")

if sample_weight_name:
losses[loss_name] = torch.mean(losses[loss_name] * sample_weights)
losses[loss_name] = torch.sum(
losses[loss_name] * sample_weights
) / torch.sum(sample_weights)
return losses

def loss(
Expand Down

0 comments on commit d7e5380

Please sign in to comment.