Skip to content

Commit

Permalink
use mean instead of sum in num/den for better numerical stability
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-gecheng committed Dec 16, 2024
1 parent 5702227 commit 4d6c0a1
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tzrec/models/match_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,9 @@ def _loss_impl(
label = _zero_int_label(pred)
losses[loss_name] = self._loss_modules[loss_name](pred, label)
if self._sample_weight:
losses[loss_name] = torch.sum(
losses[loss_name] = torch.mean(
losses[loss_name] * sample_weight
) / torch.sum(sample_weight)
) / torch.mean(sample_weight)

return losses

Expand Down
4 changes: 2 additions & 2 deletions tzrec/models/rank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,9 +217,9 @@ def _loss_impl(
raise ValueError(f"loss[{loss_type}] is not supported yet.")

if sample_weight_name:
losses[loss_name] = torch.sum(
losses[loss_name] = torch.mean(
losses[loss_name] * sample_weights
) / torch.sum(sample_weights)
) / torch.mean(sample_weights)
return losses

def loss(
Expand Down

0 comments on commit 4d6c0a1

Please sign in to comment.