From c8146b48225fa40eeaf18c7e23149e80d1ab4e25 Mon Sep 17 00:00:00 2001
From: gecheng
Date: Mon, 9 Dec 2024 18:20:50 +0800
Subject: [PATCH 1/5] rank model supports sample weights

---
 docs/source/feature/data.md |  3 +++
 tzrec/loss/jrc_loss.py      | 35 ++++++++++++++++++++++++++++-------
 tzrec/models/rank_model.py  | 16 ++++++++++++----
 3 files changed, 43 insertions(+), 11 deletions(-)

diff --git a/docs/source/feature/data.md b/docs/source/feature/data.md
index f374788..4a4d998 100644
--- a/docs/source/feature/data.md
+++ b/docs/source/feature/data.md
@@ -18,6 +18,9 @@ data_config {
 }
 ```
 
+To attach per-sample weights during training, add the sample_weight_fields
+option to data_config: sample_weight_fields: 'col_name'
+
 ### dataset_type
 
 The following [input_type](../proto.html#tzrec.protos.DatasetType) values are currently supported:
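A minimal data_config sketch with the new sample_weight_fields option (`sw_col` is a placeholder for a float weight column in the input table; other fields are elided):

```
data_config {
    ...
    sample_weight_fields: "sw_col"
}
```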
""" ce_loss = self._ce_loss(logits, labels) + batch_size = labels.shape[0] mask = torch.eq(session_ids.unsqueeze(1), session_ids.unsqueeze(0)).float() diag_index = _diag_index(labels) @@ -88,8 +102,15 @@ def forward(self, logits: Tensor, labels: Tensor, session_ids: Tensor) -> Tensor logits_neg + ((1 - neg_session_mask) + (1 - diag_neg) * y_neg) * -1e9 ) loss_neg = self._ce_loss(logits_neg, neg_diag_label) + loss_pos = loss_pos * pos_num / batch_size + loss_neg = loss_neg * neg_num / batch_size + if self._reduction != "none": + ge_loss = loss_pos + loss_neg + else: + ge_loss = torch.zeros_like(labels, dtype=torch.float) + ge_loss.index_put_(torch.where(labels == 1.0), loss_pos) + ge_loss.index_put_(torch.where(labels == 0.0), loss_neg) - ge_loss = (loss_pos * pos_num + loss_neg * neg_num) / batch_size loss = self._alpha * ce_loss + (1 - self._alpha) * ge_loss # pyre-ignore [7] return loss diff --git a/tzrec/models/rank_model.py b/tzrec/models/rank_model.py index 830125c..0c2be76 100644 --- a/tzrec/models/rank_model.py +++ b/tzrec/models/rank_model.py @@ -56,6 +56,7 @@ def __init__( super().__init__(model_config, features, labels, sample_weights, **kwargs) self._num_class = model_config.num_class self._label_name = labels[0] + self._sample_weight = sample_weights[0] if sample_weights else sample_weights self._loss_collection = {} self.embedding_group = None self.group_variational_dropouts = None @@ -153,17 +154,18 @@ def _init_loss_impl( ) -> None: loss_type = loss_cfg.WhichOneof("loss") loss_name = loss_type + suffix + reduction = "none" if self._sample_weight else "mean" if loss_type == "binary_cross_entropy": - self._loss_modules[loss_name] = nn.BCEWithLogitsLoss() + self._loss_modules[loss_name] = nn.BCEWithLogitsLoss(reduction=reduction) elif loss_type == "softmax_cross_entropy": - self._loss_modules[loss_name] = nn.CrossEntropyLoss() + self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction=reduction) elif loss_type == "jrc_loss": assert num_class == 2, f"num_class must be 2 when loss type is {loss_type}" self._loss_modules[loss_name] = JRCLoss( - alpha=loss_cfg.jrc_loss.alpha, + alpha=loss_cfg.jrc_loss.alpha, reduction=reduction ) elif loss_type == "l2_loss": - self._loss_modules[loss_name] = nn.MSELoss() + self._loss_modules[loss_name] = nn.MSELoss(reduction=reduction) else: raise ValueError(f"loss[{loss_type}] is not supported yet.") @@ -183,6 +185,9 @@ def _loss_impl( ) -> Dict[str, torch.Tensor]: losses = {} label = batch.labels[label_name] + sample_weights = ( + batch.sample_weights[self._sample_weight] if self._sample_weight else None + ) loss_type = loss_cfg.WhichOneof("loss") loss_name = loss_type + suffix @@ -205,6 +210,9 @@ def _loss_impl( losses[loss_name] = self._loss_modules[loss_name](pred, label) else: raise ValueError(f"loss[{loss_type}] is not supported yet.") + + if self._sample_weight: + losses[loss_name] = torch.mean(losses[loss_name] * sample_weights) return losses def loss( From 536449c426356006a9c8a59cdf8314e33f92647b Mon Sep 17 00:00:00 2001 From: gecheng Date: Tue, 10 Dec 2024 11:03:23 +0800 Subject: [PATCH 2/5] pass sample_weight_name into _loss_impl for multi-task rank model --- tzrec/models/multi_task_rank.py | 6 ++++++ tzrec/models/rank_model.py | 12 ++++++++---- tzrec/protos/tower.proto | 4 ++++ 3 files changed, 18 insertions(+), 4 deletions(-) diff --git a/tzrec/models/multi_task_rank.py b/tzrec/models/multi_task_rank.py index b06fac1..483dc86 100644 --- a/tzrec/models/multi_task_rank.py +++ b/tzrec/models/multi_task_rank.py @@ -76,12 +76,18 @@ 
From 536449c426356006a9c8a59cdf8314e33f92647b Mon Sep 17 00:00:00 2001
From: gecheng
Date: Tue, 10 Dec 2024 11:03:23 +0800
Subject: [PATCH 2/5] pass sample_weight_name into _loss_impl for multi-task
 rank model

---
 tzrec/models/multi_task_rank.py |  6 ++++++
 tzrec/models/rank_model.py      | 12 ++++++++----
 tzrec/protos/tower.proto        |  4 ++++
 3 files changed, 18 insertions(+), 4 deletions(-)

diff --git a/tzrec/models/multi_task_rank.py b/tzrec/models/multi_task_rank.py
index b06fac1..483dc86 100644
--- a/tzrec/models/multi_task_rank.py
+++ b/tzrec/models/multi_task_rank.py
@@ -76,12 +76,18 @@ def loss(
         for task_tower_cfg in self._task_tower_cfgs:
             tower_name = task_tower_cfg.tower_name
             label_name = task_tower_cfg.label_name
+            sample_weight_name = (
+                task_tower_cfg.sample_weight_name
+                if task_tower_cfg.sample_weight_name
+                else ""
+            )
             for loss_cfg in task_tower_cfg.losses:
                 losses.update(
                     self._loss_impl(
                         predictions,
                         batch,
                         label_name,
+                        sample_weight_name,
                         loss_cfg,
                         num_class=task_tower_cfg.num_class,
                         suffix=f"_{tower_name}",

diff --git a/tzrec/models/rank_model.py b/tzrec/models/rank_model.py
index 0c2be76..b60bbde 100644
--- a/tzrec/models/rank_model.py
+++ b/tzrec/models/rank_model.py
@@ -56,7 +56,9 @@ def __init__(
         super().__init__(model_config, features, labels, sample_weights, **kwargs)
         self._num_class = model_config.num_class
         self._label_name = labels[0]
-        self._sample_weight = sample_weights[0] if sample_weights else sample_weights
+        self._sample_weight_name = (
+            sample_weights[0] if sample_weights else sample_weights
+        )
         self._loss_collection = {}
         self.embedding_group = None
         self.group_variational_dropouts = None
@@ -154,7 +156,7 @@ def _init_loss_impl(
     ) -> None:
         loss_type = loss_cfg.WhichOneof("loss")
         loss_name = loss_type + suffix
-        reduction = "none" if self._sample_weight else "mean"
+        reduction = "none" if self._sample_weight_name else "mean"
         if loss_type == "binary_cross_entropy":
             self._loss_modules[loss_name] = nn.BCEWithLogitsLoss(reduction=reduction)
         elif loss_type == "softmax_cross_entropy":
@@ -179,6 +181,7 @@ def _loss_impl(
         predictions: Dict[str, torch.Tensor],
         batch: Batch,
         label_name: str,
+        sample_weight_name: str,
         loss_cfg: LossConfig,
         num_class: int = 1,
         suffix: str = "",
@@ -186,7 +189,7 @@ def _loss_impl(
         losses = {}
         label = batch.labels[label_name]
         sample_weights = (
-            batch.sample_weights[self._sample_weight] if self._sample_weight else None
+            batch.sample_weights[sample_weight_name] if sample_weight_name else None
         )
         loss_type = loss_cfg.WhichOneof("loss")
@@ -211,7 +214,7 @@ def _loss_impl(
         else:
             raise ValueError(f"loss[{loss_type}] is not supported yet.")
 
-        if self._sample_weight:
+        if sample_weight_name:
             losses[loss_name] = torch.mean(losses[loss_name] * sample_weights)
         return losses
@@ -226,6 +229,7 @@ def loss(
                 predictions,
                 batch,
                 self._label_name,
+                self._sample_weight_name,
                 loss_cfg,
                 num_class=self._num_class,
             )

diff --git a/tzrec/protos/tower.proto b/tzrec/protos/tower.proto
index 629ab6b..54de9cb 100644
--- a/tzrec/protos/tower.proto
+++ b/tzrec/protos/tower.proto
@@ -34,6 +34,8 @@ message TaskTower {
     optional MLP mlp = 6;
     // training loss weights
     optional float weight = 7 [default = 1.0];
+    // sample weight column name for the task
+    optional string sample_weight_name = 8;
 };
 
 message BayesTaskTower {
@@ -56,6 +58,8 @@ message BayesTaskTower {
     repeated string relation_tower_names = 8;
    // relation mlp
     optional MLP relation_mlp = 9;
+    // sample weight column name for the task
+    optional string sample_weight_name = 10;
 };
 
 message MultiWindowDINTower {
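With the proto fields above, each task tower can opt in to its own weight column. A configuration sketch (the task_towers block and the tower, label, and column names are illustrative):

```
task_towers {
    tower_name: "ctr"
    label_name: "clk"
    sample_weight_name: "sw_col"
    ...
}
```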
From e638c961b89131793955ad3e007d757de34a475e Mon Sep 17 00:00:00 2001
From: gecheng
Date: Tue, 10 Dec 2024 14:02:07 +0800
Subject: [PATCH 3/5] add jrc_loss_test with reduction=none

---
 tzrec/loss/jrc_loss_test.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)

diff --git a/tzrec/loss/jrc_loss_test.py b/tzrec/loss/jrc_loss_test.py
index 203a2c5..c3d51fa 100644
--- a/tzrec/loss/jrc_loss_test.py
+++ b/tzrec/loss/jrc_loss_test.py
@@ -39,5 +39,32 @@ def test_jrc_loss(self) -> None:
         self.assertEqual(0.7199, round(loss.item(), 4))
 
 
+class JRCLossTestReduceNone(unittest.TestCase):
+    def test_jrc_loss_reduce_none(self) -> None:
+        loss_class = JRCLoss(reduction="none")
+        logits = torch.tensor(
+            [
+                [0.9, 0.1],
+                [0.5, 0.5],
+                [0.3, 0.7],
+                [0.2, 0.8],
+                [0.8, 0.2],
+                [0.55, 0.45],
+                [0.33, 0.67],
+                [0.55, 0.45],
+            ],
+            dtype=torch.float32,
+        )
+        labels = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1])
+        session_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.int8)
+        loss = loss_class(logits, labels, session_ids)
+        rounded_loss = list(map(lambda x: round(x, 4), loss.numpy().tolist()))
+
+        answer = [0.3644, 0.5815, 0.4720, 0.4201, 0.4380, 0.5798, 0.4905, 0.6277]
+        diff = sum((map(lambda x, y: abs(x - y), answer, rounded_loss)))
+
+        self.assertEqual(0.0, diff)
+
+
 if __name__ == "__main__":
     unittest.main()

From 1d07c121c7102be250f566e457b5fafc54841d08 Mon Sep 17 00:00:00 2001
From: gecheng
Date: Tue, 10 Dec 2024 15:22:34 +0800
Subject: [PATCH 4/5] fix ge_loss scaling when reduction is none

---
 tzrec/loss/jrc_loss.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/tzrec/loss/jrc_loss.py b/tzrec/loss/jrc_loss.py
index 68d5204..85723cd 100644
--- a/tzrec/loss/jrc_loss.py
+++ b/tzrec/loss/jrc_loss.py
@@ -102,9 +102,10 @@ def forward(
             logits_neg + ((1 - neg_session_mask) + (1 - diag_neg) * y_neg) * -1e9
         )
         loss_neg = self._ce_loss(logits_neg, neg_diag_label)
-        loss_pos = loss_pos * pos_num / batch_size
-        loss_neg = loss_neg * neg_num / batch_size
+
         if self._reduction != "none":
+            loss_pos = loss_pos * pos_num / batch_size
+            loss_neg = loss_neg * neg_num / batch_size
             ge_loss = loss_pos + loss_neg
         else:
             ge_loss = torch.zeros_like(labels, dtype=torch.float)

From cf30073461f92c6428b91ad7407f651739b6bb91 Mon Sep 17 00:00:00 2001
From: gecheng
Date: Tue, 10 Dec 2024 16:16:03 +0800
Subject: [PATCH 5/5] fix jrc_loss_test for reduction=none

---
 tzrec/loss/jrc_loss_test.py | 6 +-----
 1 file changed, 1 insertion(+), 5 deletions(-)

diff --git a/tzrec/loss/jrc_loss_test.py b/tzrec/loss/jrc_loss_test.py
index c3d51fa..401bf83 100644
--- a/tzrec/loss/jrc_loss_test.py
+++ b/tzrec/loss/jrc_loss_test.py
@@ -58,12 +58,8 @@ def test_jrc_loss_reduce_none(self) -> None:
         labels = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1])
         session_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.int8)
         loss = loss_class(logits, labels, session_ids)
-        rounded_loss = list(map(lambda x: round(x, 4), loss.numpy().tolist()))
 
-        answer = [0.3644, 0.5815, 0.4720, 0.4201, 0.4380, 0.5798, 0.4905, 0.6277]
-        diff = sum((map(lambda x, y: abs(x - y), answer, rounded_loss)))
-
-        self.assertEqual(0.0, diff)
+        self.assertEqual(0.7199, round(torch.mean(loss).item(), 4))
 
 
 if __name__ == "__main__":
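A sketch of what the fixed test asserts (assuming the patched tzrec is installed): with the patch 4 fix, the per-sample losses from reduction="none" average back to the reduction="mean" result, which is 0.7199 for the tensors used in the tests above.

```python
import torch

from tzrec.loss.jrc_loss import JRCLoss

logits = torch.tensor(
    [
        [0.9, 0.1], [0.5, 0.5], [0.3, 0.7], [0.2, 0.8],
        [0.8, 0.2], [0.55, 0.45], [0.33, 0.67], [0.55, 0.45],
    ],
    dtype=torch.float32,
)
labels = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1])
session_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2])

mean_loss = JRCLoss(reduction="mean")(logits, labels, session_ids)
none_loss = JRCLoss(reduction="none")(logits, labels, session_ids)

# per-sample losses average back to the mean-reduced value
assert round(none_loss.mean().item(), 4) == round(mean_loss.item(), 4) == 0.7199
```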