Skip to content

Commit

Permalink
bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
eric-gecheng committed Dec 4, 2024
1 parent 2cc79c7 commit dcce8da
Show file tree
Hide file tree
Showing 283 changed files with 87 additions and 35 deletions.
Empty file modified .github/workflows/codestyle_ci.yml
100644 → 100755
Empty file.
Empty file modified .github/workflows/pytyping_ci.yml
100644 → 100755
Empty file.
Empty file modified .github/workflows/unittest_ci.yml
100644 → 100755
Empty file.
Empty file modified .github/workflows/unittest_cpu_ci.yml
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion .gitignore
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ protoc*
docs/source/intro.md
docs/source/proto.html

.vscode/
.vscode/
Empty file modified .pre-commit-config.yaml
100644 → 100755
Empty file.
Empty file modified .pyre_configuration
100644 → 100755
Empty file.
Empty file modified .readthedocs.yaml
100644 → 100755
Empty file.
Empty file modified .ruff.toml
100644 → 100755
Empty file.
Empty file modified LICENSE
100644 → 100755
Empty file.
Empty file modified MANIFEST.in
100644 → 100755
Empty file.
Empty file modified README.md
100644 → 100755
Empty file.
Empty file modified data/.license_header.txt
100644 → 100755
Empty file.
Empty file modified data/test/spiece.model
100644 → 100755
Empty file.
Empty file modified data/test/tokenizer.json
100644 → 100755
Empty file.
Empty file modified docker/Dockerfile
100644 → 100755
Empty file.
Empty file modified docker/pip.conf
100644 → 100755
Empty file.
Empty file modified docs/Makefile
100644 → 100755
Empty file.
Empty file modified docs/images/intro.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/models/dbmtl.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/models/dbmtl_mmoe.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/models/deepfm.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/models/din.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/models/dssm_neg_sampler.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/models/feature_groups_din.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/models/mmoe.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/models/multi_tower.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/models/ple.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/models/tdm.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/qrcode/dinggroup1.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/images/qrcode/dinggroup2.png
100644 → 100755
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Empty file modified docs/make.bat
100644 → 100755
Empty file.
Empty file modified docs/source/conf.py
100644 → 100755
Empty file.
Empty file modified docs/source/develop.md
100644 → 100755
Empty file.
Empty file modified docs/source/faq.md
100644 → 100755
Empty file.
Empty file modified docs/source/feature/data.md
100644 → 100755
Empty file.
Empty file modified docs/source/feature/feature.md
100644 → 100755
Empty file.
Empty file modified docs/source/index.rst
100644 → 100755
Empty file.
Empty file modified docs/source/models/dbmtl.md
100644 → 100755
Empty file.
Empty file modified docs/source/models/deepfm.md
100644 → 100755
Empty file.
Empty file modified docs/source/models/din.md
100644 → 100755
Empty file.
Empty file modified docs/source/models/dssm.md
100644 → 100755
Empty file.
Empty file modified docs/source/models/feature_group.md
100644 → 100755
Empty file.
Empty file modified docs/source/models/loss.md
100644 → 100755
Empty file.
Empty file modified docs/source/models/mmoe.md
100644 → 100755
Empty file.
Empty file modified docs/source/models/multi_target.rst
100644 → 100755
Empty file.
Empty file modified docs/source/models/multi_tower.md
100644 → 100755
Empty file.
Empty file modified docs/source/models/ple.md
100644 → 100755
Empty file.
Empty file modified docs/source/models/rank.rst
100644 → 100755
Empty file.
Empty file modified docs/source/models/recall.rst
100644 → 100755
Empty file.
Empty file modified docs/source/models/tdm.md
100644 → 100755
Empty file.
Empty file modified docs/source/models/user_define.md
100644 → 100755
Empty file.
Empty file modified docs/source/quick_start.rst
100644 → 100755
Empty file.
Empty file modified docs/source/quick_start/dlc_tutorial.md
100644 → 100755
Empty file.
Empty file modified docs/source/quick_start/local_tutorial.md
100644 → 100755
Empty file.
Empty file modified docs/source/quick_start/local_tutorial_tdm.md
100644 → 100755
Empty file.
Empty file modified docs/source/quick_start/local_tutorial_u2i_vec.md
100644 → 100755
Empty file.
Empty file modified docs/source/reference.md
100644 → 100755
Empty file.
Empty file modified docs/source/usage/convert_easyrec_config_to_tzrec_config.md
100644 → 100755
Empty file.
Empty file modified docs/source/usage/eval.md
100644 → 100755
Empty file.
Empty file modified docs/source/usage/export.md
100644 → 100755
Empty file.
Empty file modified docs/source/usage/feature_selection.md
100644 → 100755
Empty file.
Empty file modified docs/source/usage/predict.md
100644 → 100755
Empty file.
Empty file modified docs/source/usage/serving.md
100644 → 100755
Empty file.
Empty file modified docs/source/usage/train.md
100644 → 100755
Empty file.
Empty file modified examples/dbmtl_taobao.config
100644 → 100755
Empty file.
Empty file modified examples/dbmtl_taobao_jrc.config
100644 → 100755
Empty file.
Empty file modified examples/dbmtl_taobao_seq.config
100644 → 100755
Empty file.
Empty file modified examples/deepfm_criteo.config
100644 → 100755
Empty file.
Empty file modified examples/dssm_taobao.config
100644 → 100755
Empty file.
Empty file modified examples/dssm_v2_taobao.config
100644 → 100755
Empty file.
Empty file modified examples/mmoe_taobao.config
100644 → 100755
Empty file.
Empty file modified examples/multi_tower_din_taobao.config
100644 → 100755
Empty file.
Empty file modified examples/multi_tower_taobao.config
100644 → 100755
Empty file.
Empty file modified examples/ple_taobao.config
100644 → 100755
Empty file.
Empty file modified examples/tdm_taobao.config
100644 → 100755
Empty file.
Empty file modified requirements.txt
100644 → 100755
Empty file.
Empty file modified requirements/docs.txt
100644 → 100755
Empty file.
Empty file modified requirements/runtime.txt
100644 → 100755
Empty file.
Empty file modified requirements/test.txt
100644 → 100755
Empty file.
Empty file modified scripts/build_docker.sh
100644 → 100755
Empty file.
Empty file modified scripts/build_wheel.sh
100644 → 100755
Empty file.
Empty file modified scripts/ci_test.sh
100644 → 100755
Empty file.
Empty file modified scripts/doc/build_doc_pre_work.sh
100644 → 100755
Empty file.
Empty file modified scripts/doc/build_docs.sh
100644 → 100755
Empty file.
Empty file modified scripts/gen_proto.sh
100644 → 100755
Empty file.
Empty file modified scripts/pyre_check.py
100644 → 100755
Empty file.
Empty file modified setup.cfg
100644 → 100755
Empty file.
Empty file modified setup.py
100644 → 100755
Empty file.
Empty file modified tzrec/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/acc/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/acc/_aten_lowering_pass.py
100644 → 100755
Empty file.
Empty file modified tzrec/acc/_decompositions.py
100644 → 100755
Empty file.
Empty file modified tzrec/acc/trt_utils.py
100644 → 100755
Empty file.
Empty file modified tzrec/acc/utils.py
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/benchmark.py
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/base_eval_metric.json
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/criteo/deepfm.config
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/taobao/dbmtl.config
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/taobao/dbmtl_has_sequence.config
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/taobao/dbmtl_jrc.config
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/taobao/mmoe.config
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/taobao/mmoe_has_sequence.config
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/taobao/ple.config
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/taobao/ple_has_sequence.config
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/taobao_ccp/dbmtl.config
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/taobao_ccp/mmoe.config
100644 → 100755
Empty file.
Empty file modified tzrec/benchmark/configs/taobao_ccp/ple.config
100644 → 100755
Empty file.
Empty file modified tzrec/constant.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/csv_dataset.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/csv_dataset_test.py
100644 → 100755
Empty file.
4 changes: 2 additions & 2 deletions tzrec/datasets/data_parser.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]:

for label_name in self._labels:
output_data[label_name] = _to_tensor(input_data[label_name].to_numpy())

for weight in self._sample_weights:
output_data[weight] = _to_tensor(input_data[weight].to_numpy())

Expand Down Expand Up @@ -326,7 +326,7 @@ def to_batch(
labels = {}
for label_name in self._labels:
labels[label_name] = input_data[label_name]

sample_weights = {}
for weight in self._sample_weights:
sample_weights[weight] = input_data[weight]
Expand Down
Empty file modified tzrec/datasets/data_parser_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/dataset.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/dataset_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/odps_dataset.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/odps_dataset_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/odps_dataset_v1.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/odps_dataset_v1_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/parquet_dataset.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/parquet_dataset_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/sampler.py
100644 → 100755
Empty file.
Empty file modified tzrec/datasets/sampler_test.py
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion tzrec/datasets/utils.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":
sample_weights={
k: v.to(device=device, non_blocking=non_blocking)
for k, v in self.sample_weights.items()
}
},
)

def record_stream(self, stream: torch.Stream) -> None:
Expand Down
Empty file modified tzrec/eval.py
100644 → 100755
Empty file.
Empty file modified tzrec/export.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/combo_feature.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/combo_feature_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/expr_feature.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/expr_feature_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/feature.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/feature_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/id_feature.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/id_feature_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/lookup_feature.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/lookup_feature_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/match_feature.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/match_feature_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/overlap_feature.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/overlap_feature_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/raw_feature.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/raw_feature_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/sequence_feature.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/sequence_feature_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/tokenize_feature.py
100644 → 100755
Empty file.
Empty file modified tzrec/features/tokenize_feature_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/loss/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/loss/jrc_loss.py
100644 → 100755
Empty file.
Empty file modified tzrec/loss/jrc_loss_test.py
100644 → 100755
Empty file.
8 changes: 6 additions & 2 deletions tzrec/main.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -218,14 +218,18 @@ def _get_dataloader(


def _create_model(
model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> BaseModel:
"""Build model.
Args:
model_config (ModelConfig): easyrec model config.
features (list): list of features.
labels (list): list of label names.
sample_weights (list): list of sample weight names
Return:
model: a EasyRec Model.
Expand Down Expand Up @@ -538,7 +542,7 @@ def train_and_evaluate(
pipeline_config.model_config,
features,
list(data_config.label_fields),
list(data_config.sample_weight_fields)
list(data_config.sample_weight_fields),
)
model = TrainWrapper(model)

Expand Down
Empty file modified tzrec/metrics/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/metrics/grouped_auc.py
100644 → 100755
Empty file.
Empty file modified tzrec/metrics/grouped_auc_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/metrics/recall_at_k.py
100644 → 100755
Empty file.
Empty file modified tzrec/metrics/recall_at_k_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/models/__init__.py
100644 → 100755
Empty file.
8 changes: 6 additions & 2 deletions tzrec/models/dbmtl.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Optional

import torch
from torch import nn
Expand All @@ -35,7 +35,11 @@ class DBMTL(MultiTaskRank):
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)
assert model_config.WhichOneof("model") == "dbmtl", (
Expand Down
Empty file modified tzrec/models/dbmtl_test.py
100644 → 100755
Empty file.
8 changes: 6 additions & 2 deletions tzrec/models/deepfm.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Optional

import torch
from torch import nn
Expand All @@ -34,7 +34,11 @@ class DeepFM(RankModel):
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)
self.init_input()
Expand Down
Empty file modified tzrec/models/deepfm_test.py
100644 → 100755
Empty file.
2 changes: 1 addition & 1 deletion tzrec/models/dssm.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def __init__(
model_config: model_pb2.ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: List[str] = []
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)
name_to_feature_group = {x.group_name: x for x in model_config.feature_groups}
Expand Down
Empty file modified tzrec/models/dssm_test.py
100644 → 100755
Empty file.
4 changes: 2 additions & 2 deletions tzrec/models/dssm_v2.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

from collections import OrderedDict
from typing import Dict, List
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
model_config: model_pb2.ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: List[str] = []
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)
name_to_feature_group = {x.group_name: x for x in model_config.feature_groups}
Expand Down
Empty file modified tzrec/models/dssm_v2_test.py
100644 → 100755
Empty file.
16 changes: 12 additions & 4 deletions tzrec/models/match_model.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,11 @@ class MatchModel(BaseModel):
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)
self._num_class = model_config.num_class
Expand Down Expand Up @@ -190,7 +194,7 @@ def _init_loss_impl(self, loss_cfg: LossConfig, suffix: str = "") -> None:
assert (
loss_type == "softmax_cross_entropy"
), "match model only support softmax_cross_entropy loss now."
self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction='none')
self._loss_modules[loss_name] = nn.CrossEntropyLoss(reduction="none")

def init_loss(self) -> None:
"""Initialize loss modules."""
Expand All @@ -210,7 +214,9 @@ def _loss_impl(
) -> Dict[str, torch.Tensor]:
losses = {}
label = batch.labels[label_name]
sample_weight = batch.sample_weights[self._sample_weight] if self._sample_weight else 1.0
sample_weight = (
batch.sample_weights[self._sample_weight] if self._sample_weight else 1.0
)

loss_type = loss_cfg.WhichOneof("loss")
loss_name = loss_type + suffix
Expand All @@ -223,7 +229,9 @@ def _loss_impl(
label = _arange_int_label(pred)
else:
label = _zero_int_label(pred)
losses[loss_name] = torch.mean(self._loss_modules[loss_name](pred, label) * sample_weight)
losses[loss_name] = torch.mean(
self._loss_modules[loss_name](pred, label) * sample_weight
)
return losses

def loss(
Expand Down
Empty file modified tzrec/models/match_model_test.py
100644 → 100755
Empty file.
8 changes: 6 additions & 2 deletions tzrec/models/mmoe.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Optional

import torch
from torch import nn
Expand All @@ -34,7 +34,11 @@ class MMoE(MultiTaskRank):
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)

Expand Down
Empty file modified tzrec/models/mmoe_test.py
100644 → 100755
Empty file.
6 changes: 5 additions & 1 deletion tzrec/models/model.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,11 @@ class BaseModel(nn.Module, metaclass=_meta_cls):
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__()
self._base_model_config = model_config
Expand Down
6 changes: 5 additions & 1 deletion tzrec/models/multi_task_rank.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,11 @@ class MultiTaskRank(RankModel):
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)
self._task_tower_cfgs = list(self._model_config.task_towers)
Expand Down
Empty file modified tzrec/models/multi_task_rank_test.py
100644 → 100755
Empty file.
8 changes: 6 additions & 2 deletions tzrec/models/multi_tower.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Optional

import torch
from torch import nn
Expand All @@ -33,7 +33,11 @@ class MultiTower(RankModel):
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)

Expand Down
8 changes: 6 additions & 2 deletions tzrec/models/multi_tower_din.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Optional

import torch
from torch import nn
Expand All @@ -34,7 +34,11 @@ class MultiTowerDIN(RankModel):
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)

Expand Down
Empty file modified tzrec/models/multi_tower_din_test.py
100644 → 100755
Empty file.
10 changes: 7 additions & 3 deletions tzrec/models/multi_tower_din_trt.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# limitations under the License.

# Copyright (c) Alibaba, Inc. and its affiliates.
from typing import Dict, List
from typing import Dict, List, Optional

import torch
from torch import nn
Expand Down Expand Up @@ -54,7 +54,7 @@ def __init__(
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: List[str] = []
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)

Expand Down Expand Up @@ -129,7 +129,11 @@ class MultiTowerDINTRT(RankModel):
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)
self.embedding_group = EmbeddingGroup(
Expand Down
Empty file modified tzrec/models/multi_tower_test.py
100644 → 100755
Empty file.
8 changes: 6 additions & 2 deletions tzrec/models/ple.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Optional

import torch
from torch import nn
Expand All @@ -35,7 +35,11 @@ class PLE(MultiTaskRank):
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)
assert model_config.WhichOneof("model") == "ple", (
Expand Down
Empty file modified tzrec/models/ple_test.py
100644 → 100755
Empty file.
6 changes: 3 additions & 3 deletions tzrec/models/rank_model.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Optional

import torch
import torchmetrics
from torch import nn

from tzrec.datasets.utils import BASE_DATA_GROUP, Batch, Optional
from tzrec.datasets.utils import BASE_DATA_GROUP, Batch
from tzrec.features.feature import BaseFeature
from tzrec.loss.jrc_loss import JRCLoss
from tzrec.metrics.grouped_auc import GroupedAUC
Expand Down Expand Up @@ -50,7 +50,7 @@ def __init__(
model_config: model_pb2.ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: List[str] = []
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)
self._num_class = model_config.num_class
Expand Down
Empty file modified tzrec/models/rank_model_test.py
100644 → 100755
Empty file.
8 changes: 6 additions & 2 deletions tzrec/models/tdm.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Optional

import torch
from torch import nn
Expand All @@ -36,7 +36,11 @@ class TDM(RankModel):
"""

def __init__(
self, model_config: ModelConfig, features: List[BaseFeature], labels: List[str], sample_weights: List[str] = []
self,
model_config: ModelConfig,
features: List[BaseFeature],
labels: List[str],
sample_weights: Optional[List[str]] = None,
) -> None:
super().__init__(model_config, features, labels, sample_weights)
self.embedding_group = EmbeddingGroup(
Expand Down
Empty file modified tzrec/models/tdm_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/embedding.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/embedding_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/extraction_net.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/extraction_net_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/fm.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/fm_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/interaction.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/interaction_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/mlp.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/mlp_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/mmoe.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/mmoe_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/sequence.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/sequence_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/task_tower.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/task_tower_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/variational_dropout.py
100644 → 100755
Empty file.
Empty file modified tzrec/modules/variational_dropout_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/optim/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/optim/lr_scheduler.py
100644 → 100755
Empty file.
Empty file modified tzrec/optim/lr_scheduler_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/optim/optimizer_builder.py
100644 → 100755
Empty file.
Empty file modified tzrec/predict.py
100644 → 100755
Empty file.
Empty file modified tzrec/protos/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/protos/data.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/eval.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/export.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/feature.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/loss.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/metric.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/model.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/models/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/protos/models/match_model.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/models/multi_task_rank.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/models/rank_model.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/module.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/optimizer.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/pipeline.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/sampler.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/seq_encoder.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/tower.proto
100644 → 100755
Empty file.
Empty file modified tzrec/protos/train.proto
100644 → 100755
Empty file.
Empty file modified tzrec/tests/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/tests/configs/dbmtl_has_sequence_mock.config
100644 → 100755
Empty file.
Empty file.
Empty file modified tzrec/tests/configs/dssm_fg_mock.config
100644 → 100755
Empty file.
Empty file modified tzrec/tests/configs/dssm_mock.config
100644 → 100755
Empty file.
Empty file modified tzrec/tests/configs/dssm_v2_fg_mock.config
100644 → 100755
Empty file.
Empty file modified tzrec/tests/configs/dssm_variational_dropout_mock.config
100644 → 100755
Empty file.
Empty file modified tzrec/tests/configs/multi_tower_din_fg_mock.config
100644 → 100755
Empty file.
Empty file modified tzrec/tests/configs/multi_tower_din_mock.config
100644 → 100755
Empty file.
Empty file modified tzrec/tests/configs/multi_tower_din_trt_fg_mock.config
100644 → 100755
Empty file.
Empty file modified tzrec/tests/configs/tdm_fg_mock.config
100644 → 100755
Empty file.
Empty file modified tzrec/tests/run.py
100644 → 100755
Empty file.
Empty file modified tzrec/tests/train_eval_export_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/tests/utils.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/add_feature_info_to_config.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/add_feature_info_to_config_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/convert_easyrec_config_to_tzrec_config.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/convert_easyrec_config_to_tzrec_config_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/create_faiss_index.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/create_fg_json.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/create_online_infer_data.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/feature_selection.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/hitrate.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/list_distcp_param.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/cluster_tree.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/gen_tree/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/gen_tree/tree_builder.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/gen_tree/tree_builder_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/gen_tree/tree_cluster.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/gen_tree/tree_cluster_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/gen_tree/tree_generator.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/gen_tree/tree_generator_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/gen_tree/tree_search_util.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/gen_tree/tree_search_util_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/init_tree.py
100644 → 100755
Empty file.
Empty file modified tzrec/tools/tdm/retrieval.py
100644 → 100755
Empty file.
Empty file modified tzrec/train_eval.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/__init__.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/checkpoint_util.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/checkpoint_util_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/config_util.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/config_util_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/dist_util.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/faiss_util.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/faiss_util_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/load_class.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/load_class_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/logging_util.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/misc_util.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/plan_util.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/plan_util_test.py
100644 → 100755
Empty file.
Empty file modified tzrec/utils/test_util.py
100644 → 100755
Empty file.
Empty file modified tzrec/version.py
100644 → 100755
Empty file.

0 comments on commit dcce8da

Please sign in to comment.