[feat] rank model supports sample weight #57

Merged
3 changes: 3 additions & 0 deletions docs/source/feature/data.md
@@ -18,6 +18,9 @@ data_config {
 }
 ```

+To carry sample weights during training, add the following option to data_config:
+sample_weight_fields: 'col_name'
+
 ### dataset_type

 The following [input_type](../proto.html#tzrec.protos.DatasetType)s are currently supported:
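For example, a data_config carrying a weight column could look like the sketch below (the batch_size/label_fields values and the `sample_weight` column name are illustrative, not from this PR):

```
data_config {
    batch_size: 8192
    label_fields: "clk"
    sample_weight_fields: "sample_weight"
}
```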
36 changes: 29 additions & 7 deletions tzrec/loss/jrc_loss.py
@@ -33,15 +33,27 @@ class JRCLoss(_Loss):

     Args:
         alpha (float): cross entropy loss weight.
         same_label_loss (bool): whether to use same-label JRC loss.
+        reduction (str, optional): Specifies the reduction to apply to the
+            output: `none` | `mean`. `none`: no reduction will be applied;
+            `mean`: the weighted mean of the output is taken.
     """

-    def __init__(self, alpha: float = 0.5) -> None:
+    def __init__(
+        self,
+        alpha: float = 0.5,
+        reduction: str = "mean",
+    ) -> None:
         super().__init__()
         self._alpha = alpha
-        self._ce_loss = CrossEntropyLoss()
+        self._reduction = reduction
+        self._ce_loss = CrossEntropyLoss(reduction=reduction)

-    def forward(self, logits: Tensor, labels: Tensor, session_ids: Tensor) -> Tensor:
+    def forward(
+        self,
+        logits: Tensor,
+        labels: Tensor,
+        session_ids: Tensor,
+    ) -> Tensor:
         """JRC loss.

Args:
@@ -50,9 +62,11 @@ def forward(self, logits: Tensor, labels: Tensor, session_ids: Tensor) -> Tensor
             session_ids: a `Tensor` with shape [batch_size].

         Return:
-            loss: a `Tensor`.
+            loss: a `Tensor` with shape [batch_size] if reduction is 'none',
+                otherwise with shape ().
         """
         ce_loss = self._ce_loss(logits, labels)
+
         batch_size = labels.shape[0]
         mask = torch.eq(session_ids.unsqueeze(1), session_ids.unsqueeze(0)).float()
         diag_index = _diag_index(labels)
@@ -89,7 +103,15 @@ def forward(self, logits: Tensor, labels: Tensor, session_ids: Tensor) -> Tensor
         )
         loss_neg = self._ce_loss(logits_neg, neg_diag_label)

-        ge_loss = (loss_pos * pos_num + loss_neg * neg_num) / batch_size
+        if self._reduction != "none":
+            loss_pos = loss_pos * pos_num / batch_size
+            loss_neg = loss_neg * neg_num / batch_size
+            ge_loss = loss_pos + loss_neg
+        else:
+            ge_loss = torch.zeros_like(labels, dtype=torch.float)
+            ge_loss.index_put_(torch.where(labels == 1.0), loss_pos)
+            ge_loss.index_put_(torch.where(labels == 0.0), loss_neg)

         loss = self._alpha * ce_loss + (1 - self._alpha) * ge_loss
+        # pyre-ignore [7]
         return loss
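With reduction="none", JRCLoss now returns a per-sample loss tensor, which is what lets the rank model apply sample weights downstream. A minimal usage sketch (tensor values and weights are invented; the import path follows this repo's layout):

```python
import torch

from tzrec.loss.jrc_loss import JRCLoss

# Per-sample JRC loss, then a caller-side weighted mean.
loss_fn = JRCLoss(alpha=0.5, reduction="none")
logits = torch.tensor([[0.9, 0.1], [0.3, 0.7], [0.8, 0.2], [0.2, 0.8]])
labels = torch.tensor([0, 1, 0, 1])
session_ids = torch.tensor([1, 1, 2, 2])
weights = torch.tensor([1.0, 2.0, 0.5, 1.0])  # hypothetical sample weights

per_sample = loss_fn(logits, labels, session_ids)  # shape [4]
loss = torch.mean(per_sample * weights)  # mirrors what RankModel._loss_impl does
```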
23 changes: 23 additions & 0 deletions tzrec/loss/jrc_loss_test.py
@@ -39,5 +39,28 @@ def test_jrc_loss(self) -> None:
         self.assertEqual(0.7199, round(loss.item(), 4))


+class JRCLossTestReduceNone(unittest.TestCase):
+    def test_jrc_loss_reduce_none(self) -> None:
+        loss_class = JRCLoss(reduction="none")
+        logits = torch.tensor(
+            [
+                [0.9, 0.1],
+                [0.5, 0.5],
+                [0.3, 0.7],
+                [0.2, 0.8],
+                [0.8, 0.2],
+                [0.55, 0.45],
+                [0.33, 0.67],
+                [0.55, 0.45],
+            ],
+            dtype=torch.float32,
+        )
+        labels = torch.tensor([0, 0, 1, 1, 0, 0, 1, 1])
+        session_ids = torch.tensor([1, 1, 1, 1, 2, 2, 2, 2], dtype=torch.int8)
+        loss = loss_class(logits, labels, session_ids)
+
+        self.assertEqual(0.7199, round(torch.mean(loss).item(), 4))
+
+
 if __name__ == "__main__":
     unittest.main()
6 changes: 6 additions & 0 deletions tzrec/models/multi_task_rank.py
@@ -76,12 +76,18 @@ def loss(
         for task_tower_cfg in self._task_tower_cfgs:
             tower_name = task_tower_cfg.tower_name
             label_name = task_tower_cfg.label_name
+            sample_weight_name = (
+                task_tower_cfg.sample_weight_name
+                if task_tower_cfg.sample_weight_name
+                else ""
+            )
             for loss_cfg in task_tower_cfg.losses:
                 losses.update(
                     self._loss_impl(
                         predictions,
                         batch,
                         label_name,
+                        sample_weight_name,
                         loss_cfg,
                         num_class=task_tower_cfg.num_class,
                         suffix=f"_{tower_name}",
20 changes: 16 additions & 4 deletions tzrec/models/rank_model.py
@@ -56,6 +56,9 @@ def __init__(
         super().__init__(model_config, features, labels, sample_weights, **kwargs)
         self._num_class = model_config.num_class
         self._label_name = labels[0]
+        self._sample_weight_name = (
+            sample_weights[0] if sample_weights else sample_weights
+        )
         self._loss_collection = {}
         self.embedding_group = None
         self.group_variational_dropouts = None
@@ -153,17 +156,18 @@ def _init_loss_impl(
     ) -> None:
         loss_type = loss_cfg.WhichOneof("loss")
         loss_name = loss_type + suffix
+        reduction = "none" if self._sample_weight_name else "mean"
         if loss_type == "binary_cross_entropy":
-            self._loss_modules[loss_name] = nn.BCEWithLogitsLoss()
+            self._loss_modules[loss_name] = nn.BCEWithLogitsLoss(reduction=reduction)
         elif loss_type == "softmax_cross_entropy":
-            self._loss_modules[loss_name] = nn.CrossEntropyLoss()
+            self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction=reduction)
         elif loss_type == "jrc_loss":
             assert num_class == 2, f"num_class must be 2 when loss type is {loss_type}"
             self._loss_modules[loss_name] = JRCLoss(
-                alpha=loss_cfg.jrc_loss.alpha,
+                alpha=loss_cfg.jrc_loss.alpha, reduction=reduction
             )
         elif loss_type == "l2_loss":
-            self._loss_modules[loss_name] = nn.MSELoss()
+            self._loss_modules[loss_name] = nn.MSELoss(reduction=reduction)
         else:
             raise ValueError(f"loss[{loss_type}] is not supported yet.")

@@ -177,12 +181,16 @@ def _loss_impl(
         predictions: Dict[str, torch.Tensor],
         batch: Batch,
         label_name: str,
+        sample_weight_name: str,
         loss_cfg: LossConfig,
         num_class: int = 1,
         suffix: str = "",
     ) -> Dict[str, torch.Tensor]:
         losses = {}
         label = batch.labels[label_name]
+        sample_weights = (
+            batch.sample_weights[sample_weight_name] if sample_weight_name else None
+        )

         loss_type = loss_cfg.WhichOneof("loss")
         loss_name = loss_type + suffix
@@ -205,6 +213,9 @@ def _loss_impl(
             losses[loss_name] = self._loss_modules[loss_name](pred, label)
         else:
             raise ValueError(f"loss[{loss_type}] is not supported yet.")
+
+        if sample_weight_name:
+            losses[loss_name] = torch.mean(losses[loss_name] * sample_weights)
         return losses

     def loss(
@@ -218,6 +229,7 @@ def loss(
                 predictions,
                 batch,
                 self._label_name,
+                self._sample_weight_name,
                 loss_cfg,
                 num_class=self._num_class,
             )
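To summarize the weighting scheme above: when a sample weight column is configured, every loss module is built with reduction="none", and the batch mean is restored after multiplying the per-sample losses by the weights. A small sketch of that arithmetic (tensor values are invented):

```python
import torch
import torch.nn as nn

logits = torch.tensor([0.2, -1.3, 0.8, 0.5])
labels = torch.tensor([1.0, 0.0, 1.0, 0.0])
weights = torch.tensor([2.0, 1.0, 0.5, 1.0])  # hypothetical sample_weight column

# Weighted path: per-sample losses, then a weighted batch mean.
per_sample = nn.BCEWithLogitsLoss(reduction="none")(logits, labels)  # shape [4]
weighted = torch.mean(per_sample * weights)

# Unweighted path: reduction="mean" gives the plain batch mean directly.
unweighted = nn.BCEWithLogitsLoss(reduction="mean")(logits, labels)
```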
4 changes: 4 additions & 0 deletions tzrec/protos/tower.proto
@@ -34,6 +34,8 @@ message TaskTower {
   optional MLP mlp = 6;
   // training loss weights
   optional float weight = 7 [default = 1.0];
+  // sample weight for the task
+  optional string sample_weight_name = 8;
 };

 message BayesTaskTower {
@@ -56,6 +58,8 @@ message BayesTaskTower {
   repeated string relation_tower_names = 8;
   // relation mlp
   optional MLP relation_mlp = 9;
+  // sample weight for the task
+  optional string sample_weight_name = 10;
 };

 message MultiWindowDINTower {
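A hedged example of setting the new field in a multi-task config: sample_weight_name and binary_cross_entropy come from this diff, while the surrounding task_towers block and the tower/label/column names are assumed for illustration:

```
task_towers {
    tower_name: "ctr"
    label_name: "clk"
    sample_weight_name: "sample_weight"
    losses {
        binary_cross_entropy {}
    }
}
```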