diff --git a/docs/source/feature/data.md b/docs/source/feature/data.md index 0ae1aef..88d205d 100644 --- a/docs/source/feature/data.md +++ b/docs/source/feature/data.md @@ -18,6 +18,9 @@ data_config { } ``` +如果希望在训练过程带上样本权重,支持在data_config中增加配置项 +sample_weight_fields: 'col_name' + ### dataset_type 目前支持一下几种[input_type](../proto.html#tzrec.protos.DatasetType): diff --git a/requirements/runtime.txt b/requirements/runtime.txt index 93a7298..b052263 100644 --- a/requirements/runtime.txt +++ b/requirements/runtime.txt @@ -13,5 +13,6 @@ pyodps>=0.12.0 scikit-learn tensorboard torch==2.5.0 +torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/torch_tensorrt-2.5.0a0-cp311-cp311-linux_x86_64.whl ; python_version=="3.11" torchmetrics==1.0.3 torchrec==1.0.0 diff --git a/tzrec/loss/jrc_loss.py b/tzrec/loss/jrc_loss.py index 2afdd72..85723cd 100644 --- a/tzrec/loss/jrc_loss.py +++ b/tzrec/loss/jrc_loss.py @@ -33,15 +33,27 @@ class JRCLoss(_Loss): Args: alpha (float): cross entropy loss weight. - same_label_loss (bool): whether use same label jrc loss. + reduction (str, optional): Specifies the reduction to apply to the + output: `none` | `mean`. `none`: no reduction will be applied + , `mean`: the weighted mean of the output is taken. """ - def __init__(self, alpha: float = 0.5) -> None: + def __init__( + self, + alpha: float = 0.5, + reduction: str = "mean", + ) -> None: super().__init__() self._alpha = alpha - self._ce_loss = CrossEntropyLoss() - - def forward(self, logits: Tensor, labels: Tensor, session_ids: Tensor) -> Tensor: + self._reduction = reduction + self._ce_loss = CrossEntropyLoss(reduction=reduction) + + def forward( + self, + logits: Tensor, + labels: Tensor, + session_ids: Tensor, + ) -> Tensor: """JRC loss. Args: @@ -50,9 +62,11 @@ def forward(self, logits: Tensor, labels: Tensor, session_ids: Tensor) -> Tensor session_ids: a `Tensor` with shape [batch_size]. Return: - loss: a `Tensor`. + loss: a `Tensor` with shape [batch_size] if reduction is 'none', + otherwise with shape (). """ ce_loss = self._ce_loss(logits, labels) + batch_size = labels.shape[0] mask = torch.eq(session_ids.unsqueeze(1), session_ids.unsqueeze(0)).float() diag_index = _diag_index(labels) @@ -89,7 +103,15 @@ def forward(self, logits: Tensor, labels: Tensor, session_ids: Tensor) -> Tensor ) loss_neg = self._ce_loss(logits_neg, neg_diag_label) - ge_loss = (loss_pos * pos_num + loss_neg * neg_num) / batch_size + if self._reduction != "none": + loss_pos = loss_pos * pos_num / batch_size + loss_neg = loss_neg * neg_num / batch_size + ge_loss = loss_pos + loss_neg + else: + ge_loss = torch.zeros_like(labels, dtype=torch.float) + ge_loss.index_put_(torch.where(labels == 1.0), loss_pos) + ge_loss.index_put_(torch.where(labels == 0.0), loss_neg) + loss = self._alpha * ce_loss + (1 - self._alpha) * ge_loss # pyre-ignore [7] return loss diff --git a/tzrec/loss/jrc_loss_test.py b/tzrec/loss/jrc_loss_test.py index 203a2c5..401bf83 100644 --- a/tzrec/loss/jrc_loss_test.py +++ b/tzrec/loss/jrc_loss_test.py @@ -39,5 +39,28 @@ def test_jrc_loss(self) -> None: self.assertEqual(0.7199, round(loss.item(), 4)) +class JRCLossTestReduceNone(unittest.TestCase): + def test_jrc_loss_reduce_none(self) -> None: + loss_class = JRCLoss(reduction="none") + logits = torch.tensor( + [ + [0.9, 0.1], + [0.5, 0.5], + [0.3, 0.7], + [0.2, 0.8], + [0.8, 0.2], + [0.55, 0.45], + [0.33, 0.67], + [0.55, 0.45], + ], + dtype=torch.float32, + ) + labels = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1]) + session_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.int8) + loss = loss_class(logits, labels, session_ids) + + self.assertEqual(0.7199, round(torch.mean(loss).item(), 4)) + + if __name__ == "__main__": unittest.main() diff --git a/tzrec/main.py b/tzrec/main.py index 962f09a..b176d4a 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -8,11 +8,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -# cpu image has no torch_tensorrt -try: - import torch_tensorrt -except Exception: - pass + import copy import itertools import json @@ -1042,9 +1038,6 @@ def predict( if "PYTORCH_TENSOREXPR_FALLBACK" not in os.environ: os.environ["PYTORCH_TENSOREXPR_FALLBACK"] = "2" - if is_trt_convert: - torch_tensorrt.runtime.set_multi_device_safe_mode(True) - model: torch.jit.ScriptModule = torch.jit.load( os.path.join(scripted_model_path, "scripted_model.pt"), map_location=device ) diff --git a/tzrec/models/multi_task_rank.py b/tzrec/models/multi_task_rank.py index b06fac1..483dc86 100644 --- a/tzrec/models/multi_task_rank.py +++ b/tzrec/models/multi_task_rank.py @@ -76,12 +76,18 @@ def loss( for task_tower_cfg in self._task_tower_cfgs: tower_name = task_tower_cfg.tower_name label_name = task_tower_cfg.label_name + sample_weight_name = ( + task_tower_cfg.sample_weight_name + if task_tower_cfg.sample_weight_name + else "" + ) for loss_cfg in task_tower_cfg.losses: losses.update( self._loss_impl( predictions, batch, label_name, + sample_weight_name, loss_cfg, num_class=task_tower_cfg.num_class, suffix=f"_{tower_name}", diff --git a/tzrec/models/rank_model.py b/tzrec/models/rank_model.py index 830125c..b60bbde 100644 --- a/tzrec/models/rank_model.py +++ b/tzrec/models/rank_model.py @@ -56,6 +56,9 @@ def __init__( super().__init__(model_config, features, labels, sample_weights, **kwargs) self._num_class = model_config.num_class self._label_name = labels[0] + self._sample_weight_name = ( + sample_weights[0] if sample_weights else sample_weights + ) self._loss_collection = {} self.embedding_group = None self.group_variational_dropouts = None @@ -153,17 +156,18 @@ def _init_loss_impl( ) -> None: loss_type = loss_cfg.WhichOneof("loss") loss_name = loss_type + suffix + reduction = "none" if self._sample_weight_name else "mean" if loss_type == "binary_cross_entropy": - self._loss_modules[loss_name] = nn.BCEWithLogitsLoss() + self._loss_modules[loss_name] = nn.BCEWithLogitsLoss(reduction=reduction) elif loss_type == "softmax_cross_entropy": - self._loss_modules[loss_name] = nn.CrossEntropyLoss() + self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction=reduction) elif loss_type == "jrc_loss": assert num_class == 2, f"num_class must be 2 when loss type is {loss_type}" self._loss_modules[loss_name] = JRCLoss( - alpha=loss_cfg.jrc_loss.alpha, + alpha=loss_cfg.jrc_loss.alpha, reduction=reduction ) elif loss_type == "l2_loss": - self._loss_modules[loss_name] = nn.MSELoss() + self._loss_modules[loss_name] = nn.MSELoss(reduction=reduction) else: raise ValueError(f"loss[{loss_type}] is not supported yet.") @@ -177,12 +181,16 @@ def _loss_impl( predictions: Dict[str, torch.Tensor], batch: Batch, label_name: str, + sample_weight_name: str, loss_cfg: LossConfig, num_class: int = 1, suffix: str = "", ) -> Dict[str, torch.Tensor]: losses = {} label = batch.labels[label_name] + sample_weights = ( + batch.sample_weights[sample_weight_name] if sample_weight_name else None + ) loss_type = loss_cfg.WhichOneof("loss") loss_name = loss_type + suffix @@ -205,6 +213,9 @@ def _loss_impl( losses[loss_name] = self._loss_modules[loss_name](pred, label) else: raise ValueError(f"loss[{loss_type}] is not supported yet.") + + if sample_weight_name: + losses[loss_name] = torch.mean(losses[loss_name] * sample_weights) return losses def loss( @@ -218,6 +229,7 @@ def loss( predictions, batch, self._label_name, + self._sample_weight_name, loss_cfg, num_class=self._num_class, ) diff --git a/tzrec/protos/tower.proto b/tzrec/protos/tower.proto index 629ab6b..54de9cb 100644 --- a/tzrec/protos/tower.proto +++ b/tzrec/protos/tower.proto @@ -34,6 +34,8 @@ message TaskTower { optional MLP mlp = 6; // training loss weights optional float weight = 7 [default = 1.0]; + // sample weight for the task + optional string sample_weight_name = 8; }; message BayesTaskTower { @@ -56,6 +58,8 @@ message BayesTaskTower { repeated string relation_tower_names = 8; // relation mlp optional MLP relation_mlp = 9; + // sample weight for the task + optional string sample_weight_name = 10; }; message MultiWindowDINTower { diff --git a/tzrec/tests/rank_integration_test.py b/tzrec/tests/rank_integration_test.py index d9df533..7f7f303 100644 --- a/tzrec/tests/rank_integration_test.py +++ b/tzrec/tests/rank_integration_test.py @@ -16,12 +16,6 @@ import unittest import torch - -# cpu image has no torch_tensorrt -try: - import torch_tensorrt -except Exception: - pass from pyarrow import dataset as ds from tzrec.constant import Mode @@ -628,7 +622,6 @@ def test_multi_tower_with_fg_train_eval_export_trt(self): utils.save_predict_result_json(result_gpu, result_dict_json_path) # quant and trt - torch_tensorrt.runtime.set_multi_device_safe_mode(True) model_gpu_trt = torch.jit.load( os.path.join(self.test_dir, "trt/export/scripted_model.pt"), map_location=device,