From 4d6c0a1a3c12dbd05aa050158f39fb2bff7322cf Mon Sep 17 00:00:00 2001 From: gecheng Date: Mon, 16 Dec 2024 18:08:11 +0800 Subject: [PATCH] use mean instead of sum in num/den for better numerical stability --- tzrec/models/match_model.py | 4 ++-- tzrec/models/rank_model.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 447bbf2..5d5c9a0 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -235,9 +235,9 @@ def _loss_impl( label = _zero_int_label(pred) losses[loss_name] = self._loss_modules[loss_name](pred, label) if self._sample_weight: - losses[loss_name] = torch.sum( + losses[loss_name] = torch.mean( losses[loss_name] * sample_weight - ) / torch.sum(sample_weight) + ) / torch.mean(sample_weight) return losses diff --git a/tzrec/models/rank_model.py b/tzrec/models/rank_model.py index 7517790..9ada02c 100644 --- a/tzrec/models/rank_model.py +++ b/tzrec/models/rank_model.py @@ -217,9 +217,9 @@ def _loss_impl( raise ValueError(f"loss[{loss_type}] is not supported yet.") if sample_weight_name: - losses[loss_name] = torch.sum( + losses[loss_name] = torch.mean( losses[loss_name] * sample_weights - ) / torch.sum(sample_weights) + ) / torch.mean(sample_weights) return losses def loss(