diff --git a/.gitignore b/.gitignore index 613b0bb..9a9dc5a 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,5 @@ protoc* # Generated Docs docs/source/intro.md docs/source/proto.html + +.vscode/ diff --git a/tzrec/datasets/data_parser.py b/tzrec/datasets/data_parser.py index 7cc509a..8c8f60e 100644 --- a/tzrec/datasets/data_parser.py +++ b/tzrec/datasets/data_parser.py @@ -54,12 +54,14 @@ def __init__( self, features: List[BaseFeature], labels: Optional[List[str]] = None, + sample_weights: Optional[List[str]] = None, is_training: bool = False, fg_threads: int = 1, force_base_data_group: bool = False, ) -> None: self._features = features self._labels = labels or [] + self._sample_weights = sample_weights or [] self._is_training = is_training self._force_base_data_group = force_base_data_group @@ -153,6 +155,10 @@ def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]: for label_name in self._labels: output_data[label_name] = _to_tensor(input_data[label_name].to_numpy()) + + for weight in self._sample_weights: + output_data[weight] = _to_tensor(input_data[weight].to_numpy()) + return output_data def _parse_feature_normal( @@ -321,11 +327,16 @@ def to_batch( for label_name in self._labels: labels[label_name] = input_data[label_name] + sample_weights = {} + for weight in self._sample_weights: + sample_weights[weight] = input_data[weight] + batch = Batch( dense_features=dense_features, sparse_features=sparse_features, sequence_dense_features=sequence_dense_features, labels=labels, + sample_weights=sample_weights, # pyre-ignore [6] batch_size=batch_size, ) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 1cf9792..e3e7275 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -173,6 +173,7 @@ def __init__( self._data_parser = DataParser( features=features, labels=list(data_config.label_fields), + sample_weights=list(data_config.sample_weight_fields), is_training=self._mode == Mode.TRAIN, fg_threads=data_config.fg_threads, force_base_data_group=data_config.force_base_data_group, @@ -182,6 +183,7 @@ def __init__( self._selected_input_names = set() self._selected_input_names |= self._data_parser.feature_input_names self._selected_input_names |= set(data_config.label_fields) + self._selected_input_names |= set(data_config.sample_weight_fields) if self._mode == Mode.PREDICT: self._selected_input_names |= set(self._reserved_columns) if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT: diff --git a/tzrec/datasets/utils.py b/tzrec/datasets/utils.py index 0997b47..2fec937 100644 --- a/tzrec/datasets/utils.py +++ b/tzrec/datasets/utils.py @@ -109,6 +109,8 @@ class Batch(Pipelineable): reserves: RecordBatchTensor = field(default_factory=RecordBatchTensor) # batch_size for input-tile batch_size: int = field(default=-1) + # sample_weight + sample_weights: Dict[str, torch.Tensor] = field(default_factory=dict) def to(self, device: torch.device, non_blocking: bool = False) -> "Batch": """Copy to specified device.""" @@ -131,6 +133,10 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "Batch": }, reserves=self.reserves, batch_size=self.batch_size, + sample_weights={ + k: v.to(device=device, non_blocking=non_blocking) + for k, v in self.sample_weights.items() + }, ) def record_stream(self, stream: torch.Stream) -> None: @@ -143,6 +149,8 @@ def record_stream(self, stream: torch.Stream) -> None: v.record_stream(stream) for v in self.labels.values(): v.record_stream(stream) + for v in self.sample_weights.values(): + v.record_stream(stream) def pin_memory(self) -> "Batch": """Copy to pinned memory.""" @@ -175,6 +183,7 @@ def pin_memory(self) -> "Batch": labels={k: v.pin_memory() for k, v in self.labels.items()}, reserves=self.reserves, batch_size=self.batch_size, + sample_weights={k: v.pin_memory() for k, v in self.sample_weights.items()}, ) def to_dict( @@ -203,6 +212,8 @@ def to_dict( tensor_dict[f"{k}.lengths"] = v.lengths() for k, v in self.labels.items(): tensor_dict[f"{k}"] = v + for k, v in self.sample_weights.items(): + tensor_dict[f"{k}"] = v if self.batch_size > 0: tensor_dict["batch_size"] = torch.tensor(self.batch_size, dtype=torch.int64) return tensor_dict diff --git a/tzrec/main.py b/tzrec/main.py index 2619621..510b174 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -218,7 +218,10 @@ def _get_dataloader( def _create_model( - model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, ) -> BaseModel: """Build model. @@ -226,6 +229,7 @@ def _create_model( model_config (ModelConfig): easyrec model config. features (list): list of features. labels (list): list of label names. + sample_weights (list): list of sample weight names. Return: model: a EasyRec Model. @@ -234,7 +238,7 @@ def _create_model( # pyre-ignore [16] model_cls = BaseModel.create_class(model_cls_name) - model = model_cls(model_config, features, labels) + model = model_cls(model_config, features, labels, sample_weights) return model @@ -538,6 +542,7 @@ def train_and_evaluate( pipeline_config.model_config, features, list(data_config.label_fields), + list(data_config.sample_weight_fields), ) model = TrainWrapper(model) diff --git a/tzrec/models/dbmtl.py b/tzrec/models/dbmtl.py index 9603a30..0100244 100644 --- a/tzrec/models/dbmtl.py +++ b/tzrec/models/dbmtl.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -31,12 +31,18 @@ class DBMTL(MultiTaskRank): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) assert model_config.WhichOneof("model") == "dbmtl", ( "invalid model config: %s" % self._model_config.WhichOneof("model") ) diff --git a/tzrec/models/deepfm.py b/tzrec/models/deepfm.py index 638c788..12b05d2 100644 --- a/tzrec/models/deepfm.py +++ b/tzrec/models/deepfm.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -30,12 +30,18 @@ class DeepFM(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.init_input() self.fm = FactorizationMachine() if self.embedding_group.has_group("fm"): diff --git a/tzrec/models/dssm.py b/tzrec/models/dssm.py index 5c4ae79..47dbc3c 100644 --- a/tzrec/models/dssm.py +++ b/tzrec/models/dssm.py @@ -10,7 +10,7 @@ # limitations under the License. from collections import OrderedDict -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F @@ -99,8 +99,10 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} user_group = name_to_feature_group[self._model_config.user_tower.input] diff --git a/tzrec/models/dssm_v2.py b/tzrec/models/dssm_v2.py index 3146f6e..ff54f7a 100644 --- a/tzrec/models/dssm_v2.py +++ b/tzrec/models/dssm_v2.py @@ -10,7 +10,7 @@ # limitations under the License. from collections import OrderedDict -from typing import Dict, List +from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F @@ -79,6 +79,7 @@ class DSSMV2(MatchModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( @@ -86,8 +87,10 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} self.embedding_group = EmbeddingGroup( diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 691dc1a..8530beb 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -155,14 +155,21 @@ class MatchModel(BaseModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self._num_class = model_config.num_class self._label_name = labels[0] + self._sample_weight = sample_weights[0] if sample_weights else sample_weights self._in_batch_negative = False self._loss_collection = {} if self._model_config and hasattr(self._model_config, "in_batch_negative"): @@ -188,7 +195,8 @@ def _init_loss_impl(self, loss_cfg: LossConfig, suffix: str = "") -> None: assert ( loss_type == "softmax_cross_entropy" ), "match model only support softmax_cross_entropy loss now." - self._loss_modules[loss_name] = nn.CrossEntropyLoss() + reduction = "none" if self._sample_weight else "mean" + self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction=reduction) def init_loss(self) -> None: """Initialize loss modules.""" @@ -208,6 +216,9 @@ def _loss_impl( ) -> Dict[str, torch.Tensor]: losses = {} label = batch.labels[label_name] + sample_weight = ( + batch.sample_weights[self._sample_weight] if self._sample_weight else 1.0 + ) loss_type = loss_cfg.WhichOneof("loss") loss_name = loss_type + suffix @@ -221,6 +232,9 @@ def _loss_impl( else: label = _zero_int_label(pred) losses[loss_name] = self._loss_modules[loss_name](pred, label) + if self._sample_weight: + losses[loss_name] = torch.mean(losses[loss_name] * sample_weight) + return losses def loss( diff --git a/tzrec/models/mmoe.py b/tzrec/models/mmoe.py index 9ee3d06..146e5a7 100644 --- a/tzrec/models/mmoe.py +++ b/tzrec/models/mmoe.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -30,12 +30,18 @@ class MMoE(MultiTaskRank): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.init_input() self.group_name = self.embedding_group.group_names()[0] diff --git a/tzrec/models/model.py b/tzrec/models/model.py index d80c0c7..463dae9 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -12,7 +12,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from itertools import chain from queue import Queue -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch import torchmetrics @@ -40,12 +40,18 @@ class BaseModel(nn.Module, metaclass=_meta_cls): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__() + super().__init__(**kwargs) self._base_model_config = model_config self._model_type = model_config.WhichOneof("model") self._features = features @@ -56,6 +62,9 @@ def __init__( self._metric_modules = nn.ModuleDict() self._loss_modules = nn.ModuleDict() + if sample_weights: + self._sample_weights = sample_weights + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. diff --git a/tzrec/models/multi_task_rank.py b/tzrec/models/multi_task_rank.py index 665873e..b06fac1 100644 --- a/tzrec/models/multi_task_rank.py +++ b/tzrec/models/multi_task_rank.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch @@ -26,12 +26,18 @@ class MultiTaskRank(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self._task_tower_cfgs = list(self._model_config.task_towers) def _multi_task_output_to_prediction( diff --git a/tzrec/models/multi_tower.py b/tzrec/models/multi_tower.py index a926469..69f009b 100644 --- a/tzrec/models/multi_tower.py +++ b/tzrec/models/multi_tower.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -29,12 +29,18 @@ class MultiTower(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.init_input() self.towers = nn.ModuleDict() diff --git a/tzrec/models/multi_tower_din.py b/tzrec/models/multi_tower_din.py index a720e4b..14ef43e 100644 --- a/tzrec/models/multi_tower_din.py +++ b/tzrec/models/multi_tower_din.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -30,12 +30,18 @@ class MultiTowerDIN(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.init_input() self.towers = nn.ModuleDict() diff --git a/tzrec/models/multi_tower_din_trt.py b/tzrec/models/multi_tower_din_trt.py index cbcc36f..a16476a 100644 --- a/tzrec/models/multi_tower_din_trt.py +++ b/tzrec/models/multi_tower_din_trt.py @@ -10,7 +10,7 @@ # limitations under the License. # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Dict, List +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -45,6 +45,7 @@ class MultiTowerDINDense(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( @@ -53,8 +54,10 @@ def __init__( model_config: ModelConfig, features: List[BaseFeature], labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.grouped_features_keys = embedding_group.grouped_features_keys() @@ -123,12 +126,18 @@ class MultiTowerDINTRT(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.embedding_group = EmbeddingGroup( features, list(model_config.feature_groups) ) diff --git a/tzrec/models/ple.py b/tzrec/models/ple.py index 2717f67..9026012 100644 --- a/tzrec/models/ple.py +++ b/tzrec/models/ple.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -31,12 +31,18 @@ class PLE(MultiTaskRank): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) assert model_config.WhichOneof("model") == "ple", ( "invalid model config: %s" % self._model_config.WhichOneof("model") ) diff --git a/tzrec/models/rank_model.py b/tzrec/models/rank_model.py index 63f6f29..830125c 100644 --- a/tzrec/models/rank_model.py +++ b/tzrec/models/rank_model.py @@ -9,13 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Any, Dict, List, Optional import torch import torchmetrics from torch import nn -from tzrec.datasets.utils import BASE_DATA_GROUP, Batch, Optional +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import BaseFeature from tzrec.loss.jrc_loss import JRCLoss from tzrec.metrics.grouped_auc import GroupedAUC @@ -42,6 +42,7 @@ class RankModel(BaseModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( @@ -49,8 +50,10 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self._num_class = model_config.num_class self._label_name = labels[0] self._loss_collection = {} diff --git a/tzrec/models/tdm.py b/tzrec/models/tdm.py index b54f61b..cdbe4f5 100644 --- a/tzrec/models/tdm.py +++ b/tzrec/models/tdm.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -32,12 +32,18 @@ class TDM(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs: Any, ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.embedding_group = EmbeddingGroup( features, list(model_config.feature_groups) ) diff --git a/tzrec/protos/data.proto b/tzrec/protos/data.proto index eccbf61..7ed9abc 100644 --- a/tzrec/protos/data.proto +++ b/tzrec/protos/data.proto @@ -73,6 +73,9 @@ message DataConfig { // force padding data into same data group with same batch_size optional bool force_base_data_group = 18 [default = false]; + // sample weights + repeated string sample_weight_fields = 19; + // negative sampler oneof sampler { NegativeSampler negative_sampler = 101;