Skip to content

Commit

Permalink
Merge branch 'ty_master' into features/zch_emb
Browse files Browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji committed Dec 10, 2024
2 parents cec440b + f38d895 commit 36bb75a
Show file tree
Hide file tree
Showing 9 changed files with 83 additions and 26 deletions.
3 changes: 3 additions & 0 deletions docs/source/feature/data.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ data_config {
}
```

如果希望在训练过程带上样本权重,支持在data_config中增加配置项
sample_weight_fields: 'col_name'

### dataset_type

目前支持一下几种[input_type](../proto.html#tzrec.protos.DatasetType):
Expand Down
1 change: 1 addition & 0 deletions requirements/runtime.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,5 +13,6 @@ pyodps>=0.12.0
scikit-learn
tensorboard
torch==2.5.0
torch-tensorrt @ http://tzrec.oss-cn-beijing.aliyuncs.com/third_party/torch_tensorrt-2.5.0a0-cp311-cp311-linux_x86_64.whl ; python_version=="3.11"
torchmetrics==1.0.3
torchrec==1.0.0
36 changes: 29 additions & 7 deletions tzrec/loss/jrc_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,27 @@ class JRCLoss(_Loss):
Args:
alpha (float): cross entropy loss weight.
same_label_loss (bool): whether use same label jrc loss.
reduction (str, optional): Specifies the reduction to apply to the
output: `none` | `mean`. `none`: no reduction will be applied
, `mean`: the weighted mean of the output is taken.
"""

def __init__(self, alpha: float = 0.5) -> None:
def __init__(
self,
alpha: float = 0.5,
reduction: str = "mean",
) -> None:
super().__init__()
self._alpha = alpha
self._ce_loss = CrossEntropyLoss()

def forward(self, logits: Tensor, labels: Tensor, session_ids: Tensor) -> Tensor:
self._reduction = reduction
self._ce_loss = CrossEntropyLoss(reduction=reduction)

def forward(
self,
logits: Tensor,
labels: Tensor,
session_ids: Tensor,
) -> Tensor:
"""JRC loss.
Args:
Expand All @@ -50,9 +62,11 @@ def forward(self, logits: Tensor, labels: Tensor, session_ids: Tensor) -> Tensor
session_ids: a `Tensor` with shape [batch_size].
Return:
loss: a `Tensor`.
loss: a `Tensor` with shape [batch_size] if reduction is 'none',
otherwise with shape ().
"""
ce_loss = self._ce_loss(logits, labels)

batch_size = labels.shape[0]
mask = torch.eq(session_ids.unsqueeze(1), session_ids.unsqueeze(0)).float()
diag_index = _diag_index(labels)
Expand Down Expand Up @@ -89,7 +103,15 @@ def forward(self, logits: Tensor, labels: Tensor, session_ids: Tensor) -> Tensor
)
loss_neg = self._ce_loss(logits_neg, neg_diag_label)

ge_loss = (loss_pos * pos_num + loss_neg * neg_num) / batch_size
if self._reduction != "none":
loss_pos = loss_pos * pos_num / batch_size
loss_neg = loss_neg * neg_num / batch_size
ge_loss = loss_pos + loss_neg
else:
ge_loss = torch.zeros_like(labels, dtype=torch.float)
ge_loss.index_put_(torch.where(labels == 1.0), loss_pos)
ge_loss.index_put_(torch.where(labels == 0.0), loss_neg)

loss = self._alpha * ce_loss + (1 - self._alpha) * ge_loss
# pyre-ignore [7]
return loss
23 changes: 23 additions & 0 deletions tzrec/loss/jrc_loss_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,28 @@ def test_jrc_loss(self) -> None:
self.assertEqual(0.7199, round(loss.item(), 4))


class JRCLossTestReduceNone(unittest.TestCase):
def test_jrc_loss_reduce_none(self) -> None:
loss_class = JRCLoss(reduction="none")
logits = torch.tensor(
[
[0.9, 0.1],
[0.5, 0.5],
[0.3, 0.7],
[0.2, 0.8],
[0.8, 0.2],
[0.55, 0.45],
[0.33, 0.67],
[0.55, 0.45],
],
dtype=torch.float32,
)
labels = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1])
session_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.int8)
loss = loss_class(logits, labels, session_ids)

self.assertEqual(0.7199, round(torch.mean(loss).item(), 4))


if __name__ == "__main__":
unittest.main()
9 changes: 1 addition & 8 deletions tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# cpu image has no torch_tensorrt
try:
import torch_tensorrt
except Exception:
pass

import copy
import itertools
import json
Expand Down Expand Up @@ -1042,9 +1038,6 @@ def predict(
if "PYTORCH_TENSOREXPR_FALLBACK" not in os.environ:
os.environ["PYTORCH_TENSOREXPR_FALLBACK"] = "2"

if is_trt_convert:
torch_tensorrt.runtime.set_multi_device_safe_mode(True)

model: torch.jit.ScriptModule = torch.jit.load(
os.path.join(scripted_model_path, "scripted_model.pt"), map_location=device
)
Expand Down
6 changes: 6 additions & 0 deletions tzrec/models/multi_task_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,18 @@ def loss(
for task_tower_cfg in self._task_tower_cfgs:
tower_name = task_tower_cfg.tower_name
label_name = task_tower_cfg.label_name
sample_weight_name = (
task_tower_cfg.sample_weight_name
if task_tower_cfg.sample_weight_name
else ""
)
for loss_cfg in task_tower_cfg.losses:
losses.update(
self._loss_impl(
predictions,
batch,
label_name,
sample_weight_name,
loss_cfg,
num_class=task_tower_cfg.num_class,
suffix=f"_{tower_name}",
Expand Down
20 changes: 16 additions & 4 deletions tzrec/models/rank_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,9 @@ def __init__(
super().__init__(model_config, features, labels, sample_weights, **kwargs)
self._num_class = model_config.num_class
self._label_name = labels[0]
self._sample_weight_name = (
sample_weights[0] if sample_weights else sample_weights
)
self._loss_collection = {}
self.embedding_group = None
self.group_variational_dropouts = None
Expand Down Expand Up @@ -153,17 +156,18 @@ def _init_loss_impl(
) -> None:
loss_type = loss_cfg.WhichOneof("loss")
loss_name = loss_type + suffix
reduction = "none" if self._sample_weight_name else "mean"
if loss_type == "binary_cross_entropy":
self._loss_modules[loss_name] = nn.BCEWithLogitsLoss()
self._loss_modules[loss_name] = nn.BCEWithLogitsLoss(reduction=reduction)
elif loss_type == "softmax_cross_entropy":
self._loss_modules[loss_name] = nn.CrossEntropyLoss()
self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction=reduction)
elif loss_type == "jrc_loss":
assert num_class == 2, f"num_class must be 2 when loss type is {loss_type}"
self._loss_modules[loss_name] = JRCLoss(
alpha=loss_cfg.jrc_loss.alpha,
alpha=loss_cfg.jrc_loss.alpha, reduction=reduction
)
elif loss_type == "l2_loss":
self._loss_modules[loss_name] = nn.MSELoss()
self._loss_modules[loss_name] = nn.MSELoss(reduction=reduction)
else:
raise ValueError(f"loss[{loss_type}] is not supported yet.")

Expand All @@ -177,12 +181,16 @@ def _loss_impl(
predictions: Dict[str, torch.Tensor],
batch: Batch,
label_name: str,
sample_weight_name: str,
loss_cfg: LossConfig,
num_class: int = 1,
suffix: str = "",
) -> Dict[str, torch.Tensor]:
losses = {}
label = batch.labels[label_name]
sample_weights = (
batch.sample_weights[sample_weight_name] if sample_weight_name else None
)

loss_type = loss_cfg.WhichOneof("loss")
loss_name = loss_type + suffix
Expand All @@ -205,6 +213,9 @@ def _loss_impl(
losses[loss_name] = self._loss_modules[loss_name](pred, label)
else:
raise ValueError(f"loss[{loss_type}] is not supported yet.")

if sample_weight_name:
losses[loss_name] = torch.mean(losses[loss_name] * sample_weights)
return losses

def loss(
Expand All @@ -218,6 +229,7 @@ def loss(
predictions,
batch,
self._label_name,
self._sample_weight_name,
loss_cfg,
num_class=self._num_class,
)
Expand Down
4 changes: 4 additions & 0 deletions tzrec/protos/tower.proto
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ message TaskTower {
optional MLP mlp = 6;
// training loss weights
optional float weight = 7 [default = 1.0];
// sample weight for the task
optional string sample_weight_name = 8;
};

message BayesTaskTower {
Expand All @@ -56,6 +58,8 @@ message BayesTaskTower {
repeated string relation_tower_names = 8;
// relation mlp
optional MLP relation_mlp = 9;
// sample weight for the task
optional string sample_weight_name = 10;
};

message MultiWindowDINTower {
Expand Down
7 changes: 0 additions & 7 deletions tzrec/tests/rank_integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,6 @@
import unittest

import torch

# cpu image has no torch_tensorrt
try:
import torch_tensorrt
except Exception:
pass
from pyarrow import dataset as ds

from tzrec.constant import Mode
Expand Down Expand Up @@ -628,7 +622,6 @@ def test_multi_tower_with_fg_train_eval_export_trt(self):
utils.save_predict_result_json(result_gpu, result_dict_json_path)

# quant and trt
torch_tensorrt.runtime.set_multi_device_safe_mode(True)
model_gpu_trt = torch.jit.load(
os.path.join(self.test_dir, "trt/export/scripted_model.pt"),
map_location=device,
Expand Down

0 comments on commit 36bb75a

Please sign in to comment.