add sample_weight support for match models #50

Merged
2 changes: 2 additions & 0 deletions .gitignore
@@ -38,3 +38,5 @@ protoc*
# Generated Docs
docs/source/intro.md
docs/source/proto.html

.vscode/
11 changes: 11 additions & 0 deletions tzrec/datasets/data_parser.py
@@ -54,12 +54,14 @@ def __init__(
self,
features: List[BaseFeature],
labels: Optional[List[str]] = None,
sample_weights: Optional[List[str]] = None,
is_training: bool = False,
fg_threads: int = 1,
force_base_data_group: bool = False,
) -> None:
self._features = features
self._labels = labels or []
self._sample_weights = sample_weights or []
self._is_training = is_training
self._force_base_data_group = force_base_data_group

@@ -153,6 +155,10 @@ def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]:

for label_name in self._labels:
output_data[label_name] = _to_tensor(input_data[label_name].to_numpy())

for weight in self._sample_weights:
output_data[weight] = _to_tensor(input_data[weight].to_numpy())

return output_data

def _parse_feature_normal(
@@ -321,11 +327,16 @@ def to_batch(
for label_name in self._labels:
labels[label_name] = input_data[label_name]

sample_weights = {}
for weight in self._sample_weights:
sample_weights[weight] = input_data[weight]

batch = Batch(
dense_features=dense_features,
sparse_features=sparse_features,
sequence_dense_features=sequence_dense_features,
labels=labels,
sample_weights=sample_weights,
# pyre-ignore [6]
batch_size=batch_size,
)
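
The two data_parser.py hunks above handle weight columns exactly like labels. A minimal sketch of that conversion path, assuming pyarrow and torch; the column name "sample_weight" is illustrative, and torch.from_numpy stands in for the internal _to_tensor helper:

import pyarrow as pa
import torch

# One arrow column per configured weight name, mirroring the loop in parse().
input_data = {"sample_weight": pa.array([1.0, 0.5, 2.0])}
output_data = {}
for weight in ["sample_weight"]:
    # arrow column -> numpy -> torch tensor
    output_data[weight] = torch.from_numpy(
        input_data[weight].to_numpy(zero_copy_only=False)
    )
print(output_data["sample_weight"])  # tensor([1.0000, 0.5000, 2.0000], dtype=torch.float64)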
2 changes: 2 additions & 0 deletions tzrec/datasets/dataset.py
@@ -173,6 +173,7 @@ def __init__(
self._data_parser = DataParser(
features=features,
labels=list(data_config.label_fields),
sample_weights=list(data_config.sample_weight_fields),
is_training=self._mode == Mode.TRAIN,
fg_threads=data_config.fg_threads,
force_base_data_group=data_config.force_base_data_group,
@@ -182,6 +183,7 @@
self._selected_input_names = set()
self._selected_input_names |= self._data_parser.feature_input_names
self._selected_input_names |= set(data_config.label_fields)
self._selected_input_names |= set(data_config.sample_weight_fields)
if self._mode == Mode.PREDICT:
self._selected_input_names |= set(self._reserved_columns)
if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT:
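
For context, a hedged sketch of how a user would switch this on from the data config; the proto module path (tzrec.protos.data_pb2) is an assumption based on the repo layout, and the field values are illustrative:

from tzrec.protos import data_pb2  # assumed module path

data_config = data_pb2.DataConfig()
data_config.label_fields.append("clk")
data_config.sample_weight_fields.append("sample_weight")
# Dataset then forwards sample_weight_fields to DataParser and adds the
# columns to the selected inputs, as in the hunk above.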
11 changes: 11 additions & 0 deletions tzrec/datasets/utils.py
@@ -109,6 +109,8 @@ class Batch(Pipelineable):
reserves: RecordBatchTensor = field(default_factory=RecordBatchTensor)
# batch_size for input-tile
batch_size: int = field(default=-1)
# sample_weight
sample_weights: Dict[str, torch.Tensor] = field(default_factory=dict)

def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":
"""Copy to specified device."""
@@ -131,6 +133,10 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":
},
reserves=self.reserves,
batch_size=self.batch_size,
sample_weights={
k: v.to(device=device, non_blocking=non_blocking)
for k, v in self.sample_weights.items()
},
)

def record_stream(self, stream: torch.Stream) -> None:
@@ -143,6 +149,8 @@ def record_stream(self, stream: torch.Stream) -> None:
v.record_stream(stream)
for v in self.labels.values():
v.record_stream(stream)
for v in self.sample_weights.values():
v.record_stream(stream)

def pin_memory(self) -> "Batch":
"""Copy to pinned memory."""
@@ -175,6 +183,7 @@ def pin_memory(self) -> "Batch":
labels={k: v.pin_memory() for k, v in self.labels.items()},
reserves=self.reserves,
batch_size=self.batch_size,
sample_weights={k: v.pin_memory() for k, v in self.sample_weights.items()},
)

def to_dict(
@@ -203,6 +212,8 @@ def to_dict(
tensor_dict[f"{k}.lengths"] = v.lengths()
for k, v in self.labels.items():
tensor_dict[f"{k}"] = v
for k, v in self.sample_weights.items():
tensor_dict[f"{k}"] = v
if self.batch_size > 0:
tensor_dict["batch_size"] = torch.tensor(self.batch_size, dtype=torch.int64)
return tensor_dict
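
The Batch changes are mechanical: sample_weights is one more tensor dict that must follow the batch through device moves, stream recording, memory pinning, and flattening. A self-contained sketch of the device-move part of that pattern; TinyBatch is an illustrative stand-in, not the real Batch:

from dataclasses import dataclass, field
from typing import Dict

import torch


@dataclass
class TinyBatch:
    labels: Dict[str, torch.Tensor] = field(default_factory=dict)
    # New field: defaults to an empty dict so unweighted datasets need no change.
    sample_weights: Dict[str, torch.Tensor] = field(default_factory=dict)

    def to(self, device: torch.device, non_blocking: bool = False) -> "TinyBatch":
        # Weights move together with the labels, mirroring Batch.to().
        return TinyBatch(
            labels={
                k: v.to(device=device, non_blocking=non_blocking)
                for k, v in self.labels.items()
            },
            sample_weights={
                k: v.to(device=device, non_blocking=non_blocking)
                for k, v in self.sample_weights.items()
            },
        )


batch = TinyBatch(
    labels={"clk": torch.tensor([1.0, 0.0])},
    sample_weights={"sample_weight": torch.tensor([2.0, 0.5])},
)
batch = batch.to(torch.device("cpu"))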
9 changes: 7 additions & 2 deletions tzrec/main.py
@@ -218,14 +218,18 @@ def _get_dataloader(


def _create_model(
model_config: ModelConfig, features: List[BaseFeature], labels: List[str]
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> BaseModel:
"""Build model.

Args:
model_config (ModelConfig): easyrec model config.
features (list): list of features.
labels (list): list of label names.
sample_weights (list): list of sample weight names.

Returns:
model: an EasyRec Model.
@@ -234,7 +238,7 @@
# pyre-ignore [16]
model_cls = BaseModel.create_class(model_cls_name)

model = model_cls(model_config, features, labels)
model = model_cls(model_config, features, labels, sample_weights)
return model


@@ -538,6 +542,7 @@ def train_and_evaluate(
pipeline_config.model_config,
features,
list(data_config.label_fields),
list(data_config.sample_weight_fields),
)
model = TrainWrapper(model)

12 changes: 9 additions & 3 deletions tzrec/models/dbmtl.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Any, Dict, List, Optional

import torch
from torch import nn
@@ -31,12 +31,18 @@ class DBMTL(MultiTaskRank):
model_config (ModelConfig): an instance of ModelConfig.
features (list): list of features.
labels (list): list of label names.
sample_weights (list): sample weight names.
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str]
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(model_config, features, labels)
super().__init__(model_config, features, labels, sample_weights, **kwargs)
assert model_config.WhichOneof("model") == "dbmtl", (
"invalid model config: %s" % self._model_config.WhichOneof("model")
)
12 changes: 9 additions & 3 deletions tzrec/models/deepfm.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Any, Dict, List, Optional

import torch
from torch import nn
@@ -30,12 +30,18 @@ class DeepFM(RankModel):
model_config (ModelConfig): an instance of ModelConfig.
features (list): list of features.
labels (list): list of label names.
sample_weights (list): sample weight names.
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str]
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(model_config, features, labels)
super().__init__(model_config, features, labels, sample_weights, **kwargs)
self.init_input()
self.fm = FactorizationMachine()
if self.embedding_group.has_group("fm"):
6 changes: 4 additions & 2 deletions tzrec/models/dssm.py
@@ -10,7 +10,7 @@
# limitations under the License.

from collections import OrderedDict
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
@@ -99,8 +99,10 @@ def __init__(
model_config: model_pb2.ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(model_config, features, labels)
super().__init__(model_config, features, labels, sample_weights, **kwargs)
name_to_feature_group = {x.group_name: x for x in model_config.feature_groups}

user_group = name_to_feature_group[self._model_config.user_tower.input]
7 changes: 5 additions & 2 deletions tzrec/models/dssm_v2.py
@@ -10,7 +10,7 @@
# limitations under the License.

from collections import OrderedDict
from typing import Dict, List
from typing import Any, Dict, List, Optional

import torch
import torch.nn.functional as F
@@ -79,15 +79,18 @@ class DSSMV2(MatchModel):
model_config (ModelConfig): an instance of ModelConfig.
features (list): list of features.
labels (list): list of label names.
sample_weights (list): sample weight names.
"""

def __init__(
self,
model_config: model_pb2.ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(model_config, features, labels)
super().__init__(model_config, features, labels, sample_weights, **kwargs)
name_to_feature_group = {x.group_name: x for x in model_config.feature_groups}

self.embedding_group = EmbeddingGroup(
22 changes: 18 additions & 4 deletions tzrec/models/match_model.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional

import torch
from torch import nn
@@ -155,14 +155,21 @@ class MatchModel(BaseModel):
model_config (ModelConfig): an instance of ModelConfig.
features (list): list of features.
labels (list): list of label names.
sample_weights (list): sample weight names.
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str]
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(model_config, features, labels)
super().__init__(model_config, features, labels, sample_weights, **kwargs)
self._num_class = model_config.num_class
self._label_name = labels[0]
self._sample_weight = sample_weights[0] if sample_weights else None
self._in_batch_negative = False
self._loss_collection = {}
if self._model_config and hasattr(self._model_config, "in_batch_negative"):
@@ -188,7 +195,8 @@ def _init_loss_impl(self, loss_cfg: LossConfig, suffix: str = "") -> None:
assert (
loss_type == "softmax_cross_entropy"
), "match model only support softmax_cross_entropy loss now."
self._loss_modules[loss_name] = nn.CrossEntropyLoss()
reduction = "none" if self._sample_weight else "mean"
self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction=reduction)

def init_loss(self) -> None:
"""Initialize loss modules."""
@@ -208,6 +216,9 @@ def _loss_impl(
) -> Dict[str, torch.Tensor]:
losses = {}
label = batch.labels[label_name]
sample_weight = (
batch.sample_weights[self._sample_weight] if self._sample_weight else 1.0
)

loss_type = loss_cfg.WhichOneof("loss")
loss_name = loss_type + suffix
@@ -221,6 +232,9 @@
else:
label = _zero_int_label(pred)
losses[loss_name] = self._loss_modules[loss_name](pred, label)
if self._sample_weight:
losses[loss_name] = torch.mean(losses[loss_name] * sample_weight)

return losses

def loss(
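
This hunk is the core of the PR: with a weight column configured, the loss module switches to reduction="none" and the weighted per-sample losses are averaged manually. A runnable sketch of that arithmetic with illustrative tensors; with unit weights it reproduces the default mean reduction:

import torch
from torch import nn

logits = torch.randn(4, 3)               # [batch_size, num_class]
labels = torch.tensor([0, 2, 1, 0])      # class indices
weights = torch.tensor([1.0, 0.5, 2.0, 1.0])

# reduction="none" keeps one loss value per sample...
per_sample = nn.CrossEntropyLoss(reduction="none")(logits, labels)
# ...which is then weighted and averaged, as in _loss_impl above.
weighted_loss = torch.mean(per_sample * weights)

# Sanity check: unit weights match the default mean reduction.
assert torch.allclose(
    torch.mean(per_sample * torch.ones(4)),
    nn.CrossEntropyLoss()(logits, labels),
)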
12 changes: 9 additions & 3 deletions tzrec/models/mmoe.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Any, Dict, List, Optional

import torch
from torch import nn
@@ -30,12 +30,18 @@ class MMoE(MultiTaskRank):
model_config (ModelConfig): an instance of ModelConfig.
features (list): list of features.
labels (list): list of label names.
sample_weights (list): sample weight names.
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str]
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__(model_config, features, labels)
super().__init__(model_config, features, labels, sample_weights, **kwargs)

self.init_input()
self.group_name = self.embedding_group.group_names()[0]
15 changes: 12 additions & 3 deletions tzrec/models/model.py
@@ -12,7 +12,7 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
from itertools import chain
from queue import Queue
from typing import Dict, Iterable, List, Optional, Tuple
from typing import Any, Dict, Iterable, List, Optional, Tuple

import torch
import torchmetrics
@@ -40,12 +40,18 @@ class BaseModel(nn.Module, metaclass=_meta_cls):
model_config (ModelConfig): an instance of ModelConfig.
features (list): list of features.
labels (list): list of label names.
sample_weights (list): sample weight names.
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str]
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
**kwargs: Any,
) -> None:
super().__init__()
super().__init__(**kwargs)
self._base_model_config = model_config
self._model_type = model_config.WhichOneof("model")
self._features = features
@@ -56,6 +62,9 @@ def __init__(
self._metric_modules = nn.ModuleDict()
self._loss_modules = nn.ModuleDict()

if sample_weights:
self._sample_weights = sample_weights

def predict(self, batch: Batch) -> Dict[str, torch.Tensor]:
"""Predict the model.

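
Taken together, the constructor changes standardize one signature across BaseModel and its subclasses: an optional sample_weights list plus **kwargs forwarded through super().__init__. A minimal sketch of that cooperative pattern; TinyBase and TinyMatch are illustrative stand-ins, not tzrec classes:

from typing import Any, List, Optional


class TinyBase:
    def __init__(
        self,
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        self._labels = labels
        if sample_weights:
            self._sample_weights = sample_weights


class TinyMatch(TinyBase):
    def __init__(
        self,
        labels: List[str],
        sample_weights: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        # Forward everything up the chain, as the model diffs above do.
        super().__init__(labels, sample_weights, **kwargs)
        # Match models keep a single weight column, mirroring MatchModel.
        self._sample_weight = sample_weights[0] if sample_weights else None


model = TinyMatch(labels=["clk"], sample_weights=["sample_weight"])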