From 868f7c45559c67ac16323ae41cefec37e1f85b88 Mon Sep 17 00:00:00 2001 From: gecheng Date: Tue, 3 Dec 2024 11:46:40 +0000 Subject: [PATCH 1/5] add sample_weight support for match models --- .gitignore | 2 ++ tzrec/datasets/data_parser.py | 11 +++++++++++ tzrec/datasets/dataset.py | 2 ++ tzrec/datasets/utils.py | 11 +++++++++++ tzrec/main.py | 5 +++-- tzrec/models/dbmtl.py | 5 +++-- tzrec/models/deepfm.py | 5 +++-- tzrec/models/dssm.py | 3 ++- tzrec/models/dssm_v2.py | 4 +++- tzrec/models/match_model.py | 11 +++++++---- tzrec/models/mmoe.py | 5 +++-- tzrec/models/model.py | 6 +++++- tzrec/models/multi_task_rank.py | 5 +++-- tzrec/models/multi_tower.py | 5 +++-- tzrec/models/multi_tower_din.py | 5 +++-- tzrec/models/multi_tower_din_trt.py | 4 +++- tzrec/models/ple.py | 5 +++-- tzrec/models/rank_model.py | 4 +++- tzrec/models/tdm.py | 5 +++-- tzrec/protos/data.proto | 3 +++ 20 files changed, 79 insertions(+), 27 deletions(-) diff --git a/.gitignore b/.gitignore index 613b0bb..e417f22 100644 --- a/.gitignore +++ b/.gitignore @@ -38,3 +38,5 @@ protoc* # Generated Docs docs/source/intro.md docs/source/proto.html + +.vscode/ \ No newline at end of file diff --git a/tzrec/datasets/data_parser.py b/tzrec/datasets/data_parser.py index 7cc509a..db16995 100644 --- a/tzrec/datasets/data_parser.py +++ b/tzrec/datasets/data_parser.py @@ -54,12 +54,14 @@ def __init__( self, features: List[BaseFeature], labels: Optional[List[str]] = None, + sample_weights: Optional[List[str]] = None, is_training: bool = False, fg_threads: int = 1, force_base_data_group: bool = False, ) -> None: self._features = features self._labels = labels or [] + self._sample_weights = sample_weights or [] self._is_training = is_training self._force_base_data_group = force_base_data_group @@ -153,6 +155,10 @@ def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]: for label_name in self._labels: output_data[label_name] = _to_tensor(input_data[label_name].to_numpy()) + + for weight in self._sample_weights: + output_data[weight] = _to_tensor(input_data[weight].to_numpy()) + return output_data def _parse_feature_normal( @@ -320,12 +326,17 @@ def to_batch( labels = {} for label_name in self._labels: labels[label_name] = input_data[label_name] + + sample_weights = {} + for weight in self._sample_weights: + sample_weights[weight] = input_data[weight] batch = Batch( dense_features=dense_features, sparse_features=sparse_features, sequence_dense_features=sequence_dense_features, labels=labels, + sample_weights=sample_weights, # pyre-ignore [6] batch_size=batch_size, ) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 1cf9792..e3e7275 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -173,6 +173,7 @@ def __init__( self._data_parser = DataParser( features=features, labels=list(data_config.label_fields), + sample_weights=list(data_config.sample_weight_fields), is_training=self._mode == Mode.TRAIN, fg_threads=data_config.fg_threads, force_base_data_group=data_config.force_base_data_group, @@ -182,6 +183,7 @@ def __init__( self._selected_input_names = set() self._selected_input_names |= self._data_parser.feature_input_names self._selected_input_names |= set(data_config.label_fields) + self._selected_input_names |= set(data_config.sample_weight_fields) if self._mode == Mode.PREDICT: self._selected_input_names |= set(self._reserved_columns) if self._data_config.HasField("sampler") and self._mode != Mode.PREDICT: diff --git a/tzrec/datasets/utils.py b/tzrec/datasets/utils.py index 0997b47..3fa5d1d 100644 --- a/tzrec/datasets/utils.py +++ b/tzrec/datasets/utils.py @@ -109,6 +109,8 @@ class Batch(Pipelineable): reserves: RecordBatchTensor = field(default_factory=RecordBatchTensor) # batch_size for input-tile batch_size: int = field(default=-1) + # sample_weight + sample_weights: Dict[str, torch.Tensor] = field(default_factory=dict) def to(self, device: torch.device, non_blocking: bool = False) -> "Batch": """Copy to specified device.""" @@ -131,6 +133,10 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "Batch": }, reserves=self.reserves, batch_size=self.batch_size, + sample_weights={ + k: v.to(device=device, non_blocking=non_blocking) + for k, v in self.sample_weights.items() + } ) def record_stream(self, stream: torch.Stream) -> None: @@ -143,6 +149,8 @@ def record_stream(self, stream: torch.Stream) -> None: v.record_stream(stream) for v in self.labels.values(): v.record_stream(stream) + for v in self.sample_weights.values(): + v.record_stream(stream) def pin_memory(self) -> "Batch": """Copy to pinned memory.""" @@ -175,6 +183,7 @@ def pin_memory(self) -> "Batch": labels={k: v.pin_memory() for k, v in self.labels.items()}, reserves=self.reserves, batch_size=self.batch_size, + sample_weights={k: v.pin_memory() for k, v in self.sample_weights.items()}, ) def to_dict( @@ -203,6 +212,8 @@ def to_dict( tensor_dict[f"{k}.lengths"] = v.lengths() for k, v in self.labels.items(): tensor_dict[f"{k}"] = v + for k, v in self.sample_weights.items(): + tensor_dict[f"{k}"] = v if self.batch_size > 0: tensor_dict["batch_size"] = torch.tensor(self.batch_size, dtype=torch.int64) return tensor_dict diff --git a/tzrec/main.py b/tzrec/main.py index 2619621..a1efbeb 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -218,7 +218,7 @@ def _get_dataloader( def _create_model( - model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None ) -> BaseModel: """Build model. @@ -234,7 +234,7 @@ def _create_model( # pyre-ignore [16] model_cls = BaseModel.create_class(model_cls_name) - model = model_cls(model_config, features, labels) + model = model_cls(model_config, features, labels, sample_weights) return model @@ -538,6 +538,7 @@ def train_and_evaluate( pipeline_config.model_config, features, list(data_config.label_fields), + list(data_config.sample_weight_fields) ) model = TrainWrapper(model) diff --git a/tzrec/models/dbmtl.py b/tzrec/models/dbmtl.py index 9603a30..762db12 100644 --- a/tzrec/models/dbmtl.py +++ b/tzrec/models/dbmtl.py @@ -31,12 +31,13 @@ class DBMTL(MultiTaskRank): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) assert model_config.WhichOneof("model") == "dbmtl", ( "invalid model config: %s" % self._model_config.WhichOneof("model") ) diff --git a/tzrec/models/deepfm.py b/tzrec/models/deepfm.py index 638c788..4faaf61 100644 --- a/tzrec/models/deepfm.py +++ b/tzrec/models/deepfm.py @@ -30,12 +30,13 @@ class DeepFM(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) self.init_input() self.fm = FactorizationMachine() if self.embedding_group.has_group("fm"): diff --git a/tzrec/models/dssm.py b/tzrec/models/dssm.py index 5c4ae79..16721fb 100644 --- a/tzrec/models/dssm.py +++ b/tzrec/models/dssm.py @@ -99,8 +99,9 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], + sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} user_group = name_to_feature_group[self._model_config.user_tower.input] diff --git a/tzrec/models/dssm_v2.py b/tzrec/models/dssm_v2.py index 3146f6e..33a3cf0 100644 --- a/tzrec/models/dssm_v2.py +++ b/tzrec/models/dssm_v2.py @@ -79,6 +79,7 @@ class DSSMV2(MatchModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( @@ -86,8 +87,9 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], + sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} self.embedding_group = EmbeddingGroup( diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 691dc1a..aacd0bd 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -155,14 +155,16 @@ class MatchModel(BaseModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) self._num_class = model_config.num_class self._label_name = labels[0] + self._sample_weight = sample_weights[0] if sample_weights else sample_weights self._in_batch_negative = False self._loss_collection = {} if self._model_config and hasattr(self._model_config, "in_batch_negative"): @@ -188,7 +190,7 @@ def _init_loss_impl(self, loss_cfg: LossConfig, suffix: str = "") -> None: assert ( loss_type == "softmax_cross_entropy" ), "match model only support softmax_cross_entropy loss now." - self._loss_modules[loss_name] = nn.CrossEntropyLoss() + self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction='none') def init_loss(self) -> None: """Initialize loss modules.""" @@ -208,6 +210,7 @@ def _loss_impl( ) -> Dict[str, torch.Tensor]: losses = {} label = batch.labels[label_name] + sample_weight = batch.sample_weights[self._sample_weight] if self._sample_weight else 1.0 loss_type = loss_cfg.WhichOneof("loss") loss_name = loss_type + suffix @@ -220,7 +223,7 @@ def _loss_impl( label = _arange_int_label(pred) else: label = _zero_int_label(pred) - losses[loss_name] = self._loss_modules[loss_name](pred, label) + losses[loss_name] = torch.mean(self._loss_modules[loss_name](pred, label) * sample_weight) return losses def loss( diff --git a/tzrec/models/mmoe.py b/tzrec/models/mmoe.py index 9ee3d06..30c9ec0 100644 --- a/tzrec/models/mmoe.py +++ b/tzrec/models/mmoe.py @@ -30,12 +30,13 @@ class MMoE(MultiTaskRank): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) self.init_input() self.group_name = self.embedding_group.group_names()[0] diff --git a/tzrec/models/model.py b/tzrec/models/model.py index d80c0c7..bf1373b 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -40,10 +40,11 @@ class BaseModel(nn.Module, metaclass=_meta_cls): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None ) -> None: super().__init__() self._base_model_config = model_config @@ -56,6 +57,9 @@ def __init__( self._metric_modules = nn.ModuleDict() self._loss_modules = nn.ModuleDict() + if sample_weights: + self._sample_weights = sample_weights + def predict(self, batch: Batch) -> Dict[str, torch.Tensor]: """Predict the model. diff --git a/tzrec/models/multi_task_rank.py b/tzrec/models/multi_task_rank.py index 665873e..f1a9f25 100644 --- a/tzrec/models/multi_task_rank.py +++ b/tzrec/models/multi_task_rank.py @@ -26,12 +26,13 @@ class MultiTaskRank(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) self._task_tower_cfgs = list(self._model_config.task_towers) def _multi_task_output_to_prediction( diff --git a/tzrec/models/multi_tower.py b/tzrec/models/multi_tower.py index a926469..d5ca397 100644 --- a/tzrec/models/multi_tower.py +++ b/tzrec/models/multi_tower.py @@ -29,12 +29,13 @@ class MultiTower(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) self.init_input() self.towers = nn.ModuleDict() diff --git a/tzrec/models/multi_tower_din.py b/tzrec/models/multi_tower_din.py index a720e4b..8bae3d1 100644 --- a/tzrec/models/multi_tower_din.py +++ b/tzrec/models/multi_tower_din.py @@ -30,12 +30,13 @@ class MultiTowerDIN(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) self.init_input() self.towers = nn.ModuleDict() diff --git a/tzrec/models/multi_tower_din_trt.py b/tzrec/models/multi_tower_din_trt.py index cbcc36f..de49cc2 100644 --- a/tzrec/models/multi_tower_din_trt.py +++ b/tzrec/models/multi_tower_din_trt.py @@ -45,6 +45,7 @@ class MultiTowerDINDense(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( @@ -53,8 +54,9 @@ def __init__( model_config: ModelConfig, features: List[BaseFeature], labels: List[str], + sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) self.grouped_features_keys = embedding_group.grouped_features_keys() diff --git a/tzrec/models/ple.py b/tzrec/models/ple.py index 2717f67..3488dd0 100644 --- a/tzrec/models/ple.py +++ b/tzrec/models/ple.py @@ -31,12 +31,13 @@ class PLE(MultiTaskRank): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) assert model_config.WhichOneof("model") == "ple", ( "invalid model config: %s" % self._model_config.WhichOneof("model") ) diff --git a/tzrec/models/rank_model.py b/tzrec/models/rank_model.py index 63f6f29..60d54e0 100644 --- a/tzrec/models/rank_model.py +++ b/tzrec/models/rank_model.py @@ -42,6 +42,7 @@ class RankModel(BaseModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( @@ -49,8 +50,9 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], + sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) self._num_class = model_config.num_class self._label_name = labels[0] self._loss_collection = {} diff --git a/tzrec/models/tdm.py b/tzrec/models/tdm.py index b54f61b..09c1284 100644 --- a/tzrec/models/tdm.py +++ b/tzrec/models/tdm.py @@ -32,12 +32,13 @@ class TDM(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) self.embedding_group = EmbeddingGroup( features, list(model_config.feature_groups) ) diff --git a/tzrec/protos/data.proto b/tzrec/protos/data.proto index eccbf61..7ed9abc 100644 --- a/tzrec/protos/data.proto +++ b/tzrec/protos/data.proto @@ -73,6 +73,9 @@ message DataConfig { // force padding data into same data group with same batch_size optional bool force_base_data_group = 18 [default = false]; + // sample weights + repeated string sample_weight_fields = 19; + // negative sampler oneof sampler { NegativeSampler negative_sampler = 101; From 2cc79c75618678b066ad1b92706cfa0e9da9a283 Mon Sep 17 00:00:00 2001 From: gecheng Date: Wed, 4 Dec 2024 07:06:42 +0000 Subject: [PATCH 2/5] bug fix --- tzrec/main.py | 2 +- tzrec/models/dbmtl.py | 2 +- tzrec/models/deepfm.py | 2 +- tzrec/models/dssm.py | 2 +- tzrec/models/dssm_v2.py | 2 +- tzrec/models/match_model.py | 2 +- tzrec/models/mmoe.py | 2 +- tzrec/models/model.py | 2 +- tzrec/models/multi_task_rank.py | 2 +- tzrec/models/multi_tower.py | 2 +- tzrec/models/multi_tower_din.py | 2 +- tzrec/models/multi_tower_din_trt.py | 7 ++++--- tzrec/models/ple.py | 2 +- tzrec/models/rank_model.py | 2 +- tzrec/models/tdm.py | 2 +- 15 files changed, 18 insertions(+), 17 deletions(-) diff --git a/tzrec/main.py b/tzrec/main.py index a1efbeb..db73f04 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -218,7 +218,7 @@ def _get_dataloader( def _create_model( - model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None + model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> BaseModel: """Build model. diff --git a/tzrec/models/dbmtl.py b/tzrec/models/dbmtl.py index 762db12..d70bc5b 100644 --- a/tzrec/models/dbmtl.py +++ b/tzrec/models/dbmtl.py @@ -35,7 +35,7 @@ class DBMTL(MultiTaskRank): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) assert model_config.WhichOneof("model") == "dbmtl", ( diff --git a/tzrec/models/deepfm.py b/tzrec/models/deepfm.py index 4faaf61..4c07946 100644 --- a/tzrec/models/deepfm.py +++ b/tzrec/models/deepfm.py @@ -34,7 +34,7 @@ class DeepFM(RankModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) self.init_input() diff --git a/tzrec/models/dssm.py b/tzrec/models/dssm.py index 16721fb..3ae9c53 100644 --- a/tzrec/models/dssm.py +++ b/tzrec/models/dssm.py @@ -99,7 +99,7 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], - sample_weights: List[str] = None + sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} diff --git a/tzrec/models/dssm_v2.py b/tzrec/models/dssm_v2.py index 33a3cf0..6633073 100644 --- a/tzrec/models/dssm_v2.py +++ b/tzrec/models/dssm_v2.py @@ -87,7 +87,7 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], - sample_weights: List[str] = None + sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index aacd0bd..9992be8 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -159,7 +159,7 @@ class MatchModel(BaseModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) self._num_class = model_config.num_class diff --git a/tzrec/models/mmoe.py b/tzrec/models/mmoe.py index 30c9ec0..e6f472e 100644 --- a/tzrec/models/mmoe.py +++ b/tzrec/models/mmoe.py @@ -34,7 +34,7 @@ class MMoE(MultiTaskRank): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) diff --git a/tzrec/models/model.py b/tzrec/models/model.py index bf1373b..28dd883 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -44,7 +44,7 @@ class BaseModel(nn.Module, metaclass=_meta_cls): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> None: super().__init__() self._base_model_config = model_config diff --git a/tzrec/models/multi_task_rank.py b/tzrec/models/multi_task_rank.py index f1a9f25..49e2122 100644 --- a/tzrec/models/multi_task_rank.py +++ b/tzrec/models/multi_task_rank.py @@ -30,7 +30,7 @@ class MultiTaskRank(RankModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) self._task_tower_cfgs = list(self._model_config.task_towers) diff --git a/tzrec/models/multi_tower.py b/tzrec/models/multi_tower.py index d5ca397..ae71070 100644 --- a/tzrec/models/multi_tower.py +++ b/tzrec/models/multi_tower.py @@ -33,7 +33,7 @@ class MultiTower(RankModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) diff --git a/tzrec/models/multi_tower_din.py b/tzrec/models/multi_tower_din.py index 8bae3d1..3bec000 100644 --- a/tzrec/models/multi_tower_din.py +++ b/tzrec/models/multi_tower_din.py @@ -34,7 +34,7 @@ class MultiTowerDIN(RankModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) diff --git a/tzrec/models/multi_tower_din_trt.py b/tzrec/models/multi_tower_din_trt.py index de49cc2..eb1b6ec 100644 --- a/tzrec/models/multi_tower_din_trt.py +++ b/tzrec/models/multi_tower_din_trt.py @@ -54,7 +54,7 @@ def __init__( model_config: ModelConfig, features: List[BaseFeature], labels: List[str], - sample_weights: List[str] = None + sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) @@ -125,12 +125,13 @@ class MultiTowerDINTRT(RankModel): model_config (ModelConfig): an instance of ModelConfig. features (list): list of features. labels (list): list of label names. + sample_weights (list): sample weight names. """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str] + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> None: - super().__init__(model_config, features, labels) + super().__init__(model_config, features, labels, sample_weights) self.embedding_group = EmbeddingGroup( features, list(model_config.feature_groups) ) diff --git a/tzrec/models/ple.py b/tzrec/models/ple.py index 3488dd0..3e2d3e1 100644 --- a/tzrec/models/ple.py +++ b/tzrec/models/ple.py @@ -35,7 +35,7 @@ class PLE(MultiTaskRank): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) assert model_config.WhichOneof("model") == "ple", ( diff --git a/tzrec/models/rank_model.py b/tzrec/models/rank_model.py index 60d54e0..c16f125 100644 --- a/tzrec/models/rank_model.py +++ b/tzrec/models/rank_model.py @@ -50,7 +50,7 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], - sample_weights: List[str] = None + sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) self._num_class = model_config.num_class diff --git a/tzrec/models/tdm.py b/tzrec/models/tdm.py index 09c1284..f11004c 100644 --- a/tzrec/models/tdm.py +++ b/tzrec/models/tdm.py @@ -36,7 +36,7 @@ class TDM(RankModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = None + self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] ) -> None: super().__init__(model_config, features, labels, sample_weights) self.embedding_group = EmbeddingGroup( From 2fcc96ab740f0022d056a7add19de165a141d3a7 Mon Sep 17 00:00:00 2001 From: gecheng Date: Wed, 4 Dec 2024 17:07:31 +0800 Subject: [PATCH 3/5] bug fix --- tzrec/models/dbmtl.py | 11 ++++++++--- tzrec/models/deepfm.py | 11 ++++++++--- tzrec/models/dssm.py | 5 +++-- tzrec/models/dssm_v2.py | 7 ++++--- tzrec/models/match_model.py | 21 ++++++++++++++++----- tzrec/models/mmoe.py | 11 ++++++++--- tzrec/models/model.py | 9 +++++++-- tzrec/models/multi_task_rank.py | 9 +++++++-- tzrec/models/multi_tower.py | 11 ++++++++--- tzrec/models/multi_tower_din.py | 11 ++++++++--- tzrec/models/multi_tower_din_trt.py | 16 +++++++++++----- tzrec/models/ple.py | 11 ++++++++--- tzrec/models/rank_model.py | 9 +++++---- tzrec/models/tdm.py | 11 ++++++++--- 14 files changed, 109 insertions(+), 44 deletions(-) diff --git a/tzrec/models/dbmtl.py b/tzrec/models/dbmtl.py index d70bc5b..e850120 100644 --- a/tzrec/models/dbmtl.py +++ b/tzrec/models/dbmtl.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, List, Optional import torch from torch import nn @@ -35,9 +35,14 @@ class DBMTL(MultiTaskRank): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) assert model_config.WhichOneof("model") == "dbmtl", ( "invalid model config: %s" % self._model_config.WhichOneof("model") ) diff --git a/tzrec/models/deepfm.py b/tzrec/models/deepfm.py index 4c07946..3874046 100644 --- a/tzrec/models/deepfm.py +++ b/tzrec/models/deepfm.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, List, Optional import torch from torch import nn @@ -34,9 +34,14 @@ class DeepFM(RankModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.init_input() self.fm = FactorizationMachine() if self.embedding_group.has_group("fm"): diff --git a/tzrec/models/dssm.py b/tzrec/models/dssm.py index 3ae9c53..247f35e 100644 --- a/tzrec/models/dssm.py +++ b/tzrec/models/dssm.py @@ -99,9 +99,10 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], - sample_weights: List[str] = [] + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} user_group = name_to_feature_group[self._model_config.user_tower.input] diff --git a/tzrec/models/dssm_v2.py b/tzrec/models/dssm_v2.py index 6633073..4b9c806 100644 --- a/tzrec/models/dssm_v2.py +++ b/tzrec/models/dssm_v2.py @@ -10,7 +10,7 @@ # limitations under the License. from collections import OrderedDict -from typing import Dict, List +from typing import Dict, List, Optional import torch import torch.nn.functional as F @@ -87,9 +87,10 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], - sample_weights: List[str] = [] + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} self.embedding_group = EmbeddingGroup( diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 9992be8..5217565 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -159,9 +159,14 @@ class MatchModel(BaseModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self._num_class = model_config.num_class self._label_name = labels[0] self._sample_weight = sample_weights[0] if sample_weights else sample_weights @@ -190,7 +195,8 @@ def _init_loss_impl(self, loss_cfg: LossConfig, suffix: str = "") -> None: assert ( loss_type == "softmax_cross_entropy" ), "match model only support softmax_cross_entropy loss now." - self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction='none') + reduction = "none" if self._sample_weight else "mean" + self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction=reduction) def init_loss(self) -> None: """Initialize loss modules.""" @@ -210,7 +216,9 @@ def _loss_impl( ) -> Dict[str, torch.Tensor]: losses = {} label = batch.labels[label_name] - sample_weight = batch.sample_weights[self._sample_weight] if self._sample_weight else 1.0 + sample_weight = ( + batch.sample_weights[self._sample_weight] if self._sample_weight else 1.0 + ) loss_type = loss_cfg.WhichOneof("loss") loss_name = loss_type + suffix @@ -223,7 +231,10 @@ def _loss_impl( label = _arange_int_label(pred) else: label = _zero_int_label(pred) - losses[loss_name] = torch.mean(self._loss_modules[loss_name](pred, label) * sample_weight) + losses[loss_name] = self._loss_modules[loss_name](pred, label) + if self._sample_weight: + losses[loss_name] = torch.mean(losses[loss_name] * sample_weight) + return losses def loss( diff --git a/tzrec/models/mmoe.py b/tzrec/models/mmoe.py index e6f472e..c7ed5d2 100644 --- a/tzrec/models/mmoe.py +++ b/tzrec/models/mmoe.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, List, Optional import torch from torch import nn @@ -34,9 +34,14 @@ class MMoE(MultiTaskRank): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.init_input() self.group_name = self.embedding_group.group_names()[0] diff --git a/tzrec/models/model.py b/tzrec/models/model.py index 28dd883..ab7024d 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -44,9 +44,14 @@ class BaseModel(nn.Module, metaclass=_meta_cls): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__() + super().__init__(**kwargs) self._base_model_config = model_config self._model_type = model_config.WhichOneof("model") self._features = features diff --git a/tzrec/models/multi_task_rank.py b/tzrec/models/multi_task_rank.py index 49e2122..8dfd57f 100644 --- a/tzrec/models/multi_task_rank.py +++ b/tzrec/models/multi_task_rank.py @@ -30,9 +30,14 @@ class MultiTaskRank(RankModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self._task_tower_cfgs = list(self._model_config.task_towers) def _multi_task_output_to_prediction( diff --git a/tzrec/models/multi_tower.py b/tzrec/models/multi_tower.py index ae71070..76ec3dd 100644 --- a/tzrec/models/multi_tower.py +++ b/tzrec/models/multi_tower.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, List, Optional import torch from torch import nn @@ -33,9 +33,14 @@ class MultiTower(RankModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.init_input() self.towers = nn.ModuleDict() diff --git a/tzrec/models/multi_tower_din.py b/tzrec/models/multi_tower_din.py index 3bec000..b5d245b 100644 --- a/tzrec/models/multi_tower_din.py +++ b/tzrec/models/multi_tower_din.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, List, Optional import torch from torch import nn @@ -34,9 +34,14 @@ class MultiTowerDIN(RankModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.init_input() self.towers = nn.ModuleDict() diff --git a/tzrec/models/multi_tower_din_trt.py b/tzrec/models/multi_tower_din_trt.py index eb1b6ec..62051f1 100644 --- a/tzrec/models/multi_tower_din_trt.py +++ b/tzrec/models/multi_tower_din_trt.py @@ -10,7 +10,7 @@ # limitations under the License. # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Dict, List +from typing import Dict, List, Optional import torch from torch import nn @@ -54,9 +54,10 @@ def __init__( model_config: ModelConfig, features: List[BaseFeature], labels: List[str], - sample_weights: List[str] = [] + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.grouped_features_keys = embedding_group.grouped_features_keys() @@ -129,9 +130,14 @@ class MultiTowerDINTRT(RankModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.embedding_group = EmbeddingGroup( features, list(model_config.feature_groups) ) diff --git a/tzrec/models/ple.py b/tzrec/models/ple.py index 3e2d3e1..0582199 100644 --- a/tzrec/models/ple.py +++ b/tzrec/models/ple.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, List, Optional import torch from torch import nn @@ -35,9 +35,14 @@ class PLE(MultiTaskRank): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) assert model_config.WhichOneof("model") == "ple", ( "invalid model config: %s" % self._model_config.WhichOneof("model") ) diff --git a/tzrec/models/rank_model.py b/tzrec/models/rank_model.py index c16f125..1c71317 100644 --- a/tzrec/models/rank_model.py +++ b/tzrec/models/rank_model.py @@ -9,13 +9,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, List, Optional import torch import torchmetrics from torch import nn -from tzrec.datasets.utils import BASE_DATA_GROUP, Batch, Optional +from tzrec.datasets.utils import BASE_DATA_GROUP, Batch from tzrec.features.feature import BaseFeature from tzrec.loss.jrc_loss import JRCLoss from tzrec.metrics.grouped_auc import GroupedAUC @@ -50,9 +50,10 @@ def __init__( model_config: model_pb2.ModelConfig, features: List[BaseFeature], labels: List[str], - sample_weights: List[str] = [] + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self._num_class = model_config.num_class self._label_name = labels[0] self._loss_collection = {} diff --git a/tzrec/models/tdm.py b/tzrec/models/tdm.py index f11004c..5c8bb63 100644 --- a/tzrec/models/tdm.py +++ b/tzrec/models/tdm.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List +from typing import Dict, List, Optional import torch from torch import nn @@ -36,9 +36,14 @@ class TDM(RankModel): """ def __init__( - self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + self, + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, + **kwargs, ) -> None: - super().__init__(model_config, features, labels, sample_weights) + super().__init__(model_config, features, labels, sample_weights, **kwargs) self.embedding_group = EmbeddingGroup( features, list(model_config.feature_groups) ) From 42038fda2d6d71d51b02cf0800f8766db1f9b134 Mon Sep 17 00:00:00 2001 From: gecheng Date: Wed, 4 Dec 2024 17:52:55 +0800 Subject: [PATCH 4/5] code style fix --- tzrec/main.py | 8 ++++++-- tzrec/models/dbmtl.py | 4 ++-- tzrec/models/deepfm.py | 4 ++-- tzrec/models/dssm.py | 4 ++-- tzrec/models/dssm_v2.py | 4 ++-- tzrec/models/match_model.py | 4 ++-- tzrec/models/mmoe.py | 4 ++-- tzrec/models/model.py | 4 ++-- tzrec/models/multi_task_rank.py | 4 ++-- tzrec/models/multi_tower.py | 4 ++-- tzrec/models/multi_tower_din.py | 4 ++-- tzrec/models/multi_tower_din_trt.py | 6 +++--- tzrec/models/ple.py | 4 ++-- tzrec/models/rank_model.py | 4 ++-- tzrec/models/tdm.py | 4 ++-- 15 files changed, 35 insertions(+), 31 deletions(-) diff --git a/tzrec/main.py b/tzrec/main.py index db73f04..510b174 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -218,7 +218,10 @@ def _get_dataloader( def _create_model( - model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = [] + model_config: ModelConfig, + features: List[BaseFeature], + labels: List[str], + sample_weights: Optional[List[str]] = None, ) -> BaseModel: """Build model. @@ -226,6 +229,7 @@ def _create_model( model_config (ModelConfig): easyrec model config. features (list): list of features. labels (list): list of label names. + sample_weights (list): list of sample weight names. Return: model: a EasyRec Model. @@ -538,7 +542,7 @@ def train_and_evaluate( pipeline_config.model_config, features, list(data_config.label_fields), - list(data_config.sample_weight_fields) + list(data_config.sample_weight_fields), ) model = TrainWrapper(model) diff --git a/tzrec/models/dbmtl.py b/tzrec/models/dbmtl.py index e850120..0100244 100644 --- a/tzrec/models/dbmtl.py +++ b/tzrec/models/dbmtl.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -40,7 +40,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) assert model_config.WhichOneof("model") == "dbmtl", ( diff --git a/tzrec/models/deepfm.py b/tzrec/models/deepfm.py index 3874046..12b05d2 100644 --- a/tzrec/models/deepfm.py +++ b/tzrec/models/deepfm.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -39,7 +39,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) self.init_input() diff --git a/tzrec/models/dssm.py b/tzrec/models/dssm.py index 247f35e..47dbc3c 100644 --- a/tzrec/models/dssm.py +++ b/tzrec/models/dssm.py @@ -10,7 +10,7 @@ # limitations under the License. from collections import OrderedDict -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F @@ -100,7 +100,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} diff --git a/tzrec/models/dssm_v2.py b/tzrec/models/dssm_v2.py index 4b9c806..ff54f7a 100644 --- a/tzrec/models/dssm_v2.py +++ b/tzrec/models/dssm_v2.py @@ -10,7 +10,7 @@ # limitations under the License. from collections import OrderedDict -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch import torch.nn.functional as F @@ -88,7 +88,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) name_to_feature_group = {x.group_name: x for x in model_config.feature_groups} diff --git a/tzrec/models/match_model.py b/tzrec/models/match_model.py index 5217565..8530beb 100644 --- a/tzrec/models/match_model.py +++ b/tzrec/models/match_model.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -164,7 +164,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) self._num_class = model_config.num_class diff --git a/tzrec/models/mmoe.py b/tzrec/models/mmoe.py index c7ed5d2..146e5a7 100644 --- a/tzrec/models/mmoe.py +++ b/tzrec/models/mmoe.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -39,7 +39,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) diff --git a/tzrec/models/model.py b/tzrec/models/model.py index ab7024d..463dae9 100644 --- a/tzrec/models/model.py +++ b/tzrec/models/model.py @@ -12,7 +12,7 @@ # Copyright (c) Alibaba, Inc. and its affiliates. from itertools import chain from queue import Queue -from typing import Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple import torch import torchmetrics @@ -49,7 +49,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(**kwargs) self._base_model_config = model_config diff --git a/tzrec/models/multi_task_rank.py b/tzrec/models/multi_task_rank.py index 8dfd57f..b06fac1 100644 --- a/tzrec/models/multi_task_rank.py +++ b/tzrec/models/multi_task_rank.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch @@ -35,7 +35,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) self._task_tower_cfgs = list(self._model_config.task_towers) diff --git a/tzrec/models/multi_tower.py b/tzrec/models/multi_tower.py index 76ec3dd..69f009b 100644 --- a/tzrec/models/multi_tower.py +++ b/tzrec/models/multi_tower.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -38,7 +38,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) diff --git a/tzrec/models/multi_tower_din.py b/tzrec/models/multi_tower_din.py index b5d245b..14ef43e 100644 --- a/tzrec/models/multi_tower_din.py +++ b/tzrec/models/multi_tower_din.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -39,7 +39,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) diff --git a/tzrec/models/multi_tower_din_trt.py b/tzrec/models/multi_tower_din_trt.py index 62051f1..a16476a 100644 --- a/tzrec/models/multi_tower_din_trt.py +++ b/tzrec/models/multi_tower_din_trt.py @@ -10,7 +10,7 @@ # limitations under the License. # Copyright (c) Alibaba, Inc. and its affiliates. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -55,7 +55,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) @@ -135,7 +135,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) self.embedding_group = EmbeddingGroup( diff --git a/tzrec/models/ple.py b/tzrec/models/ple.py index 0582199..9026012 100644 --- a/tzrec/models/ple.py +++ b/tzrec/models/ple.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -40,7 +40,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) assert model_config.WhichOneof("model") == "ple", ( diff --git a/tzrec/models/rank_model.py b/tzrec/models/rank_model.py index 1c71317..830125c 100644 --- a/tzrec/models/rank_model.py +++ b/tzrec/models/rank_model.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch import torchmetrics @@ -51,7 +51,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) self._num_class = model_config.num_class diff --git a/tzrec/models/tdm.py b/tzrec/models/tdm.py index 5c8bb63..cdbe4f5 100644 --- a/tzrec/models/tdm.py +++ b/tzrec/models/tdm.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional import torch from torch import nn @@ -41,7 +41,7 @@ def __init__( features: List[BaseFeature], labels: List[str], sample_weights: Optional[List[str]] = None, - **kwargs, + **kwargs: Any, ) -> None: super().__init__(model_config, features, labels, sample_weights, **kwargs) self.embedding_group = EmbeddingGroup( From 28d2c0a369e16ffa6b5ba45f5ba8e9a36be5de04 Mon Sep 17 00:00:00 2001 From: gecheng Date: Wed, 4 Dec 2024 18:46:59 +0800 Subject: [PATCH 5/5] code style fix --- .gitignore | 2 +- tzrec/datasets/data_parser.py | 4 ++-- tzrec/datasets/utils.py | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index e417f22..9a9dc5a 100644 --- a/.gitignore +++ b/.gitignore @@ -39,4 +39,4 @@ protoc* docs/source/intro.md docs/source/proto.html -.vscode/ \ No newline at end of file +.vscode/ diff --git a/tzrec/datasets/data_parser.py b/tzrec/datasets/data_parser.py index db16995..8c8f60e 100644 --- a/tzrec/datasets/data_parser.py +++ b/tzrec/datasets/data_parser.py @@ -155,7 +155,7 @@ def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]: for label_name in self._labels: output_data[label_name] = _to_tensor(input_data[label_name].to_numpy()) - + for weight in self._sample_weights: output_data[weight] = _to_tensor(input_data[weight].to_numpy()) @@ -326,7 +326,7 @@ def to_batch( labels = {} for label_name in self._labels: labels[label_name] = input_data[label_name] - + sample_weights = {} for weight in self._sample_weights: sample_weights[weight] = input_data[weight] diff --git a/tzrec/datasets/utils.py b/tzrec/datasets/utils.py index 3fa5d1d..2fec937 100644 --- a/tzrec/datasets/utils.py +++ b/tzrec/datasets/utils.py @@ -136,7 +136,7 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "Batch": sample_weights={ k: v.to(device=device, non_blocking=non_blocking) for k, v in self.sample_weights.items() - } + }, ) def record_stream(self, stream: torch.Stream) -> None: