From 7b331c22d2e333c73ad8e72df7e354dbdfbd095a Mon Sep 17 00:00:00 2001 From: chengaofei Date: Thu, 21 Nov 2024 16:19:40 +0800 Subject: [PATCH] [feat] support easyrec config convert to tzrec config (#37) --- .pyre_configuration | 4 +- .ruff.toml | 1 + docs/source/index.rst | 1 + .../convert_easyrec_config_to_tzrec_config.md | 30 + tzrec/constant.py | 3 + .../convert_easyrec_config_to_tzrec_config.py | 661 ++++++++++++++++ ...easyrec_config_to_tzrec_config_test_tmp.py | 736 ++++++++++++++++++ 7 files changed, 1435 insertions(+), 1 deletion(-) create mode 100644 docs/source/usage/convert_easyrec_config_to_tzrec_config.md create mode 100644 tzrec/tools/convert_easyrec_config_to_tzrec_config.py create mode 100644 tzrec/tools/convert_easyrec_config_to_tzrec_config_test_tmp.py diff --git a/.pyre_configuration b/.pyre_configuration index 6c2f954..b0a2088 100644 --- a/.pyre_configuration +++ b/.pyre_configuration @@ -4,7 +4,9 @@ "tzrec/*/*_test.py", "tzrec/tests/*.py", "tzrec/utils/load_class.py", - "tzrec/acc/_*.py" + "tzrec/acc/_*.py", + "tzrec/tools/convert_easyrec_config_to_tzrec_config.py", + "tzrec/*/*_test_tmp.py" ], "site_package_search_strategy": "all", "source_directories": [ diff --git a/.ruff.toml b/.ruff.toml index 90a3b90..3665ead 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -3,6 +3,7 @@ lint.ignore = ["D100", "D104", "D105", "D107"] [lint.per-file-ignores] "*_test.py" = ["D100", "D101", "D102", "D103", "D104", "D105", "D106", "D107"] +"*_test_tmp.py" = ["D100", "D101", "D102", "D103", "D104", "D105", "D106", "D107"] [lint.pydocstyle] convention = "google" diff --git a/docs/source/index.rst b/docs/source/index.rst index 46e1e70..a643e2f 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -36,6 +36,7 @@ Welcome to TorchEasyRec's documentation! usage/predict usage/serving usage/feature_selection + usage/convert_easyrec_config_to_tzrec_config .. toctree:: :maxdepth: 1 diff --git a/docs/source/usage/convert_easyrec_config_to_tzrec_config.md b/docs/source/usage/convert_easyrec_config_to_tzrec_config.md new file mode 100644 index 0000000..ae549c0 --- /dev/null +++ b/docs/source/usage/convert_easyrec_config_to_tzrec_config.md @@ -0,0 +1,30 @@ +# EasyRec迁移TorchEasyRec + +推荐模型一般特征和模型配置较为复杂,TorchEasyRec提供了配置转换工具,可以便捷地将EasyRec的配置文件转换为TorchEasyRec文件。 + +## 转换命令 + +torcheasyrec的pipeline.config包含了feature generate的配置,因此需要有easyrec训练使用的pipeline.config和fg.json两部分才可以转换为torcheasyrec的pipeline.config + +```bash +PYTHONPATH=. python tzrec/tools/convert_easyrec_config_to_tzrec_config.py \ + --easyrec_config_path ./easyrec.config \ + --fg_json_path ./fg.json \ + --output_tzrec_config_path ./tzrec.config +``` + +- --easyrec_config_path: easyrec训练使用的pipeline.config路径 +- --fg_json_path: easyrec训练和推理使用的fg.json路径 +- --output_tzrec_config_path: 生成tzrec的config路径 + +如果使用自定义的EasyRec安装包,应使用如下转换命令 + +```bash +EASYREC_URL=http://xxx.whl \ +PYTHONPATH=. python tzrec/tools/convert_easyrec_config_to_tzrec_config.py \ + --easyrec_config_path ./easyrec.config \ + --fg_json_path ./fg.json \ + --output_tzrec_config_path ./tzrec.config +``` + +- EASYREC_URL: 是http开头的EasyRec tar包或者whl的url地址 diff --git a/tzrec/constant.py b/tzrec/constant.py index 714c506..8ebd6b4 100644 --- a/tzrec/constant.py +++ b/tzrec/constant.py @@ -19,3 +19,6 @@ class Mode(Enum): TRAIN = 1 EVAL = 2 PREDICT = 3 + + +EASYREC_VERSION = "0.7.5" diff --git a/tzrec/tools/convert_easyrec_config_to_tzrec_config.py b/tzrec/tools/convert_easyrec_config_to_tzrec_config.py new file mode 100644 index 0000000..1463b8e --- /dev/null +++ b/tzrec/tools/convert_easyrec_config_to_tzrec_config.py @@ -0,0 +1,661 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import argparse +import io +import json +import os +import sys +import tarfile +import tempfile +import zipfile +from collections import OrderedDict + +import requests +from google.protobuf import descriptor_pool, symbol_database, text_format + +from tzrec.constant import EASYREC_VERSION +from tzrec.protos import feature_pb2 as tzrec_feature_pb2 +from tzrec.protos import ( + loss_pb2, + metric_pb2, + model_pb2, + module_pb2, + seq_encoder_pb2, + tower_pb2, +) +from tzrec.protos import pipeline_pb2 as tzrec_pipeline_pb2 +from tzrec.protos.data_pb2 import DatasetType +from tzrec.protos.models import match_model_pb2, multi_task_rank_pb2, rank_model_pb2 +from tzrec.utils.logging_util import logger + + +def _get_easyrec(local_cache_dir): + """Get easyrec whl and extract.""" + whl_path = os.environ.get("EASYREC_URL") + if whl_path is None: + whl_path = ( + f"https://easyrec.oss-cn-beijing.aliyuncs.com/release/whls/" + f"easy_rec-{EASYREC_VERSION}-py2.py3-none-any.whl" + ) + r = requests.get(whl_path) + logger.info(f"down easyrec from {whl_path}") + if "tar.gz" in whl_path: + try: + with tarfile.open(fileobj=io.BytesIO(r.content)) as tar: + tar.extractall(path=local_cache_dir) + local_package_dir = local_cache_dir + except Exception: + logger.error(f"invalid {EASYREC_VERSION} tar.") + local_package_dir = None + else: + try: + with zipfile.ZipFile(io.BytesIO(r.content)) as f: + f.extractall(local_cache_dir) + local_package_dir = local_cache_dir + except zipfile.BadZipfile: + logger.error(f"invalid {EASYREC_VERSION} whl.") + local_package_dir = None + return local_package_dir + + +try: + import easyrec # noqa: F401 +except ImportError: + local_cache_dir = tempfile.mkdtemp(prefix="tzrec_tmp") + local_package_dir = _get_easyrec(local_cache_dir) + with open(os.path.join(local_package_dir, "easy_rec/__init__.py"), "w") as f: + f.write("") + sys.path.append(local_package_dir) + _sym = symbol_database.Default() + _sym.pool = descriptor_pool.DescriptorPool() +finally: + # os.environ["PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION"] = "python" + from easy_rec.python.protos import pipeline_pb2 as easyrec_pipeline_pb2 + from easy_rec.python.protos.feature_config_pb2 import FeatureConfig, WideOrDeep + from easy_rec.python.protos.loss_pb2 import LossType + + +class ConvertConfig(object): + """Convert EasyRec config to tzrec config. + + Args: + easyrec_config_path (str): EasyRec config file path. + fg_json_path (str): EasyRec use fg.json file path. + output_tzrec_config_path (str): TzRec config file path will create. + """ + + def __init__(self, easyrec_config_path, fg_json_path, output_tzrec_config_path): + self.output_tzrec_config_path = output_tzrec_config_path + self.easyrec_config = self.load_easyrec_config(easyrec_config_path) + self.fg_json = self.load_easyrec_fg_json(fg_json_path) + self.feature_to_fg = {} + self.sub_sequence_to_group = {} + self.sequence_feature_to_fg = {} + self.analyse_fg() + + def analyse_fg(self): + """Analysis fg.json.""" + for feat in self.fg_json["features"]: + if "sequence_name" in feat: + sequence_name = feat["sequence_name"] + for sub_feat in feat["features"]: + self.sub_sequence_to_group[ + f"{sequence_name}__{sub_feat['feature_name']}" + ] = sequence_name + self.sequence_feature_to_fg[sequence_name] = feat + + else: + feature_name = feat["feature_name"] + self.feature_to_fg[feature_name] = feat + + def load_easyrec_config(self, path): + """Load easyrec config.""" + easyrec_config = easyrec_pipeline_pb2.EasyRecConfig() + with open(path, "r", encoding="utf-8") as f: + cfg_str = f.read() + text_format.Merge(cfg_str, easyrec_config) + return easyrec_config + + def load_easyrec_fg_json(self, path): + """Load easyrec use fg.json.""" + with open(path, "r", encoding="utf-8") as f: + fg_json = json.load(f) + return fg_json + + def _create_train_config(self, pipeline_config): + """Create easy_rec train config.""" + if not pipeline_config.HasField("train_config"): + train_config_str = """ + train_config { + sparse_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 + use_tensorboard: false + }""" + text_format.Merge(train_config_str, pipeline_config) + return pipeline_config + + def _create_eval_config(self, pipeline_config): + """Create tzrec train config.""" + if not pipeline_config.HasField("eval_config"): + eval_config_str = "eval_config {}" + text_format.Merge(eval_config_str, pipeline_config) + return pipeline_config + + def _create_data_config(self, pipeline_config): + """Create tzrec data config.""" + label_fields = list(self.easyrec_config.data_config.label_fields) + pipeline_config.data_config.batch_size = ( + self.easyrec_config.data_config.batch_size + ) + pipeline_config.data_config.dataset_type = DatasetType.OdpsDataset + pipeline_config.data_config.fg_encoded = True + pipeline_config.data_config.label_fields.extend(label_fields) + pipeline_config.data_config.num_workers = 8 + pipeline_config.data_config.odps_data_quota_name = "" + return pipeline_config + + def _create_feature_config(self, pipeline_config): + """Create tzrec feature config.""" + easyrec_feature_config = FeatureConfig() + seq_group_cfg = OrderedDict() + for cfg in self.easyrec_config.feature_configs: + if cfg.feature_name: + feature_name = cfg.feature_name + else: + feature_name = list(cfg.input_names)[0] + input_names = cfg.input_names + feature_type = cfg.feature_type + + if feature_name in self.feature_to_fg: + fg_json = self.feature_to_fg[feature_name] + elif feature_name in self.sub_sequence_to_group: + pass + elif input_names[0] in self.feature_to_fg: + fg_json = self.feature_to_fg[input_names[0]] + else: + logger.error(f"in easyrec config {feature_name} not in fg.json") + + feature_config = None + if feature_type == easyrec_feature_config.IdFeature: + feature_config = tzrec_feature_pb2.FeatureConfig() + feature = tzrec_feature_pb2.IdFeature() + feature.feature_name = feature_name + feature.expression = fg_json["expression"] + feature.embedding_dim = cfg.embedding_dim + feature.hash_bucket_size = cfg.hash_bucket_size + feature_config.ClearField("feature") + feature_config.id_feature.CopyFrom(feature) + elif feature_type == easyrec_feature_config.TagFeature: + feature_config = tzrec_feature_pb2.FeatureConfig() + feature = tzrec_feature_pb2.IdFeature() + feature.feature_name = feature_name + feature.expression = fg_json["expression"] + feature.embedding_dim = cfg.embedding_dim + feature.hash_bucket_size = cfg.hash_bucket_size + if cfg.HasField("kv_separator"): + feature.weighted = True + feature_config.ClearField("feature") + feature_config.id_feature.CopyFrom(feature) + elif feature_type == easyrec_feature_config.SequenceFeature: + if feature_name in self.sub_sequence_to_group: + sequence_name = self.sub_sequence_to_group[feature_name] + if sequence_name in seq_group_cfg: + seq_group_cfg[sequence_name].append(cfg) + else: + seq_group_cfg[sequence_name] = [cfg] + elif feature_name in self.feature_to_fg: + feature_config = tzrec_feature_pb2.FeatureConfig() + if cfg.sub_feature_type == easyrec_feature_config.IdFeature: + feature = tzrec_feature_pb2.SequenceIdFeature() + feature.feature_name = feature_name + feature.expression = self.feature_to_fg[feature_name][ + "expression" + ] + feature.embedding_dim = cfg.embedding_dim + feature.hash_bucket_size = cfg.hash_bucket_size + feature_config.ClearField("feature") + feature_config.sequence_id_feature.CopyFrom(feature) + else: + feature = tzrec_feature_pb2.SequenceRawFeature() + feature.feature_name = feature_name + feature.expression = self.feature_to_fg[feature_name][ + "expression" + ] + boundaries = list(cfg.boundaries) + feature.embedding_dim = cfg.embedding_dim + if len(boundaries): + feature.boundaries.extend(boundaries) + feature_config.ClearField("feature") + feature_config.sequence_raw_feature.CopyFrom(feature) + else: + logger.error(f"sequences feature: {feature_name} can't converted") + elif feature_type == easyrec_feature_config.RawFeature: + feature_config = tzrec_feature_pb2.FeatureConfig() + if fg_json["feature_type"] == "lookup_feature": + feature = tzrec_feature_pb2.LookupFeature() + feature.feature_name = feature_name + map = fg_json["map"] + key = fg_json["key"] + boundaries = list(cfg.boundaries) + feature.feature_name = feature_name + feature.map = map + feature.key = key + feature.embedding_dim = cfg.embedding_dim + if len(boundaries): + feature.boundaries.extend(boundaries) + feature_config.ClearField("feature") + feature_config.lookup_feature.CopyFrom(feature) + else: + feature = tzrec_feature_pb2.RawFeature() + feature.feature_name = feature_name + feature.expression = fg_json["expression"] + boundaries = list(cfg.boundaries) + feature.embedding_dim = cfg.embedding_dim + if len(boundaries): + feature.boundaries.extend(boundaries) + feature_config.ClearField("feature") + feature_config.raw_feature.CopyFrom(feature) + elif feature_type == easyrec_feature_config.ComboFeature: + feature_config = tzrec_feature_pb2.FeatureConfig() + feature = tzrec_feature_pb2.ComboFeature() + feature.feature_name = feature_name + for input in list(cfg.input_names): + if input in self.feature_to_fg: + tmp_fg_json = self.feature_to_fg[input] + feature.expression.append(tmp_fg_json["expression"]) + else: + raise ValueError(f"{cfg} input_names:{input} not in fg json") + feature.embedding_dim = cfg.embedding_dim + feature.hash_bucket_size = cfg.hash_bucket_size + feature_config.ClearField("feature") + feature_config.combo_feature.CopyFrom(feature) + elif feature_type == easyrec_feature_config.LookupFeature: + feature_config = tzrec_feature_pb2.FeatureConfig() + feature = tzrec_feature_pb2.LookupFeature() + feature.feature_name = feature_name + map_f = cfg.input_names[0] + key_f = cfg.input_names[1] + if map_f in self.feature_to_fg: + feature.map = self.feature_to_fg[map_f]["expression"] + else: + raise ValueError(f"{cfg} input names: {map_f} not in fg.json") + if key_f in self.feature_to_fg: + feature.key = self.feature_to_fg[key_f]["expression"] + else: + raise ValueError(f"{cfg} input names: {map_f} not in fg.json") + feature.embedding_dim = cfg.embedding_dim + if len(list(cfg.boundaries)): + feature.boundaries.extend(list(cfg.boundaries)) + feature_config.ClearField("feature") + feature_config.lookup_feature.CopyFrom(feature) + else: + logger.error(f"{feature_name} can't converted") + if feature_config is not None: + pipeline_config.feature_configs.append(feature_config) + for seq_name, sub_cfgs in seq_group_cfg.items(): + sequence_fg = self.sequence_feature_to_fg[seq_name] + feature_config = tzrec_feature_pb2.FeatureConfig() + sequence_feature_config = tzrec_feature_pb2.SequenceFeature() + sequence_feature_config.sequence_name = sequence_fg["sequence_name"] + sequence_feature_config.sequence_length = sequence_fg["sequence_length"] + sequence_feature_config.sequence_delim = sequence_fg["sequence_delim"] + features = sequence_fg["features"] + seq_feature_to_fg = {} + for feature in features: + seq_feature_to_fg[f'{seq_name}__{feature["feature_name"]}'] = feature + for cfg in sub_cfgs: + sub_feature_cfg = tzrec_feature_pb2.SeqFeatureConfig() + feature_name = ( + cfg.feature_name if cfg.feature_name else cfg.input_names[0] + ) + if feature_name in seq_feature_to_fg: + seq_feature_fg = seq_feature_to_fg[feature_name] + if cfg.sub_feature_type == easyrec_feature_config.IdFeature: + feature = tzrec_feature_pb2.IdFeature() + feature.feature_name = seq_feature_fg["feature_name"] + feature.expression = seq_feature_fg["expression"] + feature.embedding_dim = cfg.embedding_dim + feature.hash_bucket_size = cfg.hash_bucket_size + sub_feature_cfg.ClearField("feature") + sub_feature_cfg.id_feature.CopyFrom(feature) + else: + feature = tzrec_feature_pb2.RawFeature() + feature.feature_name = seq_feature_fg["feature_name"] + feature.expression = seq_feature_fg["expression"] + boundaries = list(cfg.boundaries) + feature.embedding_dim = cfg.embedding_dim + if len(boundaries): + feature.boundaries.extend(boundaries) + sub_feature_cfg.ClearField("feature") + sub_feature_cfg.raw_feature.CopyFrom(feature) + sequence_feature_config.features.append(sub_feature_cfg) + else: + logger.error( + f"sequence feature: {feature_name} not config in fg.json" + ) + + feature_config.sequence_feature.CopyFrom(sequence_feature_config) + pipeline_config.feature_configs.append(feature_config) + + return pipeline_config + + def _easyrec_dnn_2_tzrec_mlp(self, dnn): + """Convert easyrec dnn to tzrec mlp.""" + mlp = module_pb2.MLP() + mlp.hidden_units.extend(dnn.hidden_units) + mlp.dropout_ratio.extend(dnn.dropout_ratio) + mlp.use_bn = dnn.use_bn + return mlp + + def _easyrec_loss_2_tzrec_loss(self, easyrec_loss): + """Convert easyrec loss to tzrec loss.""" + tzrec_loss = loss_pb2.LossConfig() + loss_type = easyrec_loss.loss_type + if loss_type == LossType.JRC_LOSS: + tzrec_loss.jrc_loss.CopyFrom(loss_pb2.JRCLoss()) + elif loss_type == LossType.L2_LOSS: + tzrec_loss.l2_loss.CopyFrom(loss_pb2.L2Loss()) + elif loss_type == LossType.SOFTMAX_CROSS_ENTROPY: + tzrec_loss.softmax_cross_entropy.CopyFrom(loss_pb2.SoftmaxCrossEntropy()) + elif loss_type == LossType.CLASSIFICATION: + tzrec_loss.binary_cross_entropy.CopyFrom(loss_pb2.BinaryCrossEntropy()) + else: + logger.error( + f"{easyrec_loss} is not convert to tzrec loss, please adaptation" + ) + return tzrec_loss + + def _easyrec_metrics_2_tzrec_metrics(self, easyrec_metric): + """Convert easyrec metric to tzrec metric.""" + metric = metric_pb2.MetricConfig() + metric_type = easyrec_metric.WhichOneof("metric") + easyrec_metric_ob = getattr(easyrec_metric, metric_type) + if metric_type == "auc": + metric.auc.CopyFrom(metric_pb2.AUC()) + elif metric_type == "gauc": + tzrec_metric_ob = metric_pb2.GroupedAUC( + grouping_key=easyrec_metric_ob.uid_field + ) + metric.grouped_auc.CopyFrom(tzrec_metric_ob) + elif metric_type == "recall_at_topk": + metric.recall_at_k.CopyFrom(metric_pb2.RecallAtK()) + elif metric_type == "mean_absolute_error": + metric.mean_absolute_error.CopyFrom(metric_pb2.MeanAbsoluteError()) + elif metric_type == "mean_squared_error": + metric.mean_squared_error.CopyFrom(metric_pb2.MeanSquaredError()) + elif metric_type == "accuracy": + metric.accuracy.CopyFrom(metric_pb2.Accuracy()) + else: + logger.error( + f"{easyrec_metric} is not convert to tzrec metric, please adaptation" + ) + return metric + + def _easyrec_bayes_tower_2_tzrec_bayes_tower(self, easyrec_bayes_task_tower): + """Convert easyrec bayes tower to tzrec bayes tower.""" + tzrec_bayes_task_tower = tower_pb2.BayesTaskTower() + tzrec_bayes_task_tower.tower_name = easyrec_bayes_task_tower.tower_name + tzrec_bayes_task_tower.label_name = easyrec_bayes_task_tower.label_name + tzrec_bayes_task_tower.num_class = easyrec_bayes_task_tower.num_class + tzrec_bayes_task_tower.relation_tower_names.extend( + easyrec_bayes_task_tower.relation_tower_names + ) + mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_bayes_task_tower.dnn) + tzrec_bayes_task_tower.mlp.CopyFrom(mlp) + relation_mlp = self._easyrec_dnn_2_tzrec_mlp( + easyrec_bayes_task_tower.relation_dnn + ) + tzrec_bayes_task_tower.relation_mlp.CopyFrom(relation_mlp) + for loss in easyrec_bayes_task_tower.losses: + tzrec_bayes_task_tower.losses.append(self._easyrec_loss_2_tzrec_loss(loss)) + for metric in easyrec_bayes_task_tower.metrics_set: + tzrec_bayes_task_tower.metrics.append( + self._easyrec_metrics_2_tzrec_metrics(metric) + ) + return tzrec_bayes_task_tower + + def _easyrec_task_tower_2_tzrec_task_tower(self, easyrec_task_tower): + """Convert easyrec task tower to tzrec task tower.""" + tzrec_task_tower = tower_pb2.TaskTower() + tzrec_task_tower.tower_name = easyrec_task_tower.tower_name + tzrec_task_tower.label_name = easyrec_task_tower.label_name + tzrec_task_tower.num_class = easyrec_task_tower.num_class + mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_task_tower.dnn) + tzrec_task_tower.mlp.CopyFrom(mlp) + for loss in easyrec_task_tower.losses: + tzrec_task_tower.losses.append(self._easyrec_loss_2_tzrec_loss(loss)) + for metric in easyrec_task_tower.metrics_set: + tzrec_task_tower.metrics.append( + self._easyrec_metrics_2_tzrec_metrics(metric) + ) + return tzrec_task_tower + + def _easyrec_tower_2_tzrec_tower(self, easyrec_tower): + """Convert easyrec tower to tzrec tower.""" + tzrec_tower = tower_pb2.Tower() + tzrec_tower.input = easyrec_tower.input + mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_tower.dnn) + tzrec_tower.mlp.CopyFrom(mlp) + return tzrec_tower + + def _easyrec_dssm_tower_2_tzrec_tower(self, easyrec_dssm_tower): + """Convert easyrec dssm tower to tzrec tower.""" + tzrec_tower = tower_pb2.Tower() + tzrec_tower.input = easyrec_dssm_tower.id + mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_dssm_tower.dnn) + tzrec_tower.mlp.CopyFrom(mlp) + return tzrec_tower + + def _easyrec_extraction_network_2_tzrec_extraction_network( + self, easyrec_extraction_network + ): + """Convert easyrec extraction net to tzrec extraction net.""" + tzrec_extraction_network = module_pb2.ExtractionNetwork() + tzrec_extraction_network.network_name = easyrec_extraction_network.network_name + tzrec_extraction_network.expert_num_per_task = ( + easyrec_extraction_network.expert_num_per_task + ) + tzrec_extraction_network.share_num = easyrec_extraction_network.share_num + task_expert_net = self._easyrec_dnn_2_tzrec_mlp( + easyrec_extraction_network.task_expert_net + ) + tzrec_extraction_network.task_expert_net = task_expert_net + share_expert_net = self._easyrec_dnn_2_tzrec_mlp( + easyrec_extraction_network.share_expert_net + ) + tzrec_extraction_network.share_expert_net = share_expert_net + return tzrec_extraction_network + + def _convert_model_feature_group(self, easyrec_feature_groups): + """Convert easyrec feature group to tzrec feature group.""" + tz_feature_groups = [] + for easy_feature_group in easyrec_feature_groups: + tz_feature_group = model_pb2.FeatureGroupConfig() + tz_feature_group.group_name = easy_feature_group.group_name + tz_feature_group.feature_names.extend(easy_feature_group.feature_names) + if easy_feature_group.wide_deep == WideOrDeep.WIDE: + tz_feature_group.group_type = model_pb2.FeatureGroupType.WIDE + else: + tz_feature_group.group_type = model_pb2.FeatureGroupType.DEEP + for i, easyrec_sequence_group in enumerate( + easy_feature_group.sequence_features + ): + tz_seq_group = model_pb2.SeqGroupConfig() + tz_seq_encoder = seq_encoder_pb2.SeqEncoderConfig() + seq_encoder = seq_encoder_pb2.DINEncoder() + if easyrec_sequence_group.HasField("group_name"): + group_name = easyrec_sequence_group.group_name + else: + group_name = f"seq_{i}" + tz_seq_group.group_name = group_name + seq_encoder.input = group_name + mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_sequence_group.seq_dnn) + seq_encoder.attn_mlp.CopyFrom(mlp) + tz_seq_encoder.din_encoder.CopyFrom(seq_encoder) + for seq_att_map in easyrec_sequence_group.seq_att_map: + tz_seq_group.feature_names.extend(seq_att_map.key) + tz_seq_group.feature_names.extend(seq_att_map.hist_seq) + tz_seq_group.feature_names.extend(seq_att_map.aux_hist_seq) + tz_feature_group.sequence_groups.append(tz_seq_group) + tz_feature_group.sequence_encoders.append(tz_seq_encoder) + tz_feature_groups.append(tz_feature_group) + return tz_feature_groups + + def _convert_model_config(self, easyrec_model_config, tz_model_config): + """Convert easyrec model config to tzrec model config.""" + model_class = easyrec_model_config.model_class + model_type = easyrec_model_config.WhichOneof("model") + easyrec_model_config = getattr(easyrec_model_config, model_type) + if model_class == "DBMTL": + tz_model_config_ob = multi_task_rank_pb2.DBMTL() + bottom_mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.bottom_dnn) + expert_mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.expert_dnn) + tz_model_config_ob.bottom_mlp.CopyFrom(bottom_mlp) + tz_model_config_ob.expert_mlp.CopyFrom(expert_mlp) + tz_model_config_ob.num_expert = easyrec_model_config.num_expert + for task_tower in easyrec_model_config.task_towers: + tz_task_tower = self._easyrec_bayes_tower_2_tzrec_bayes_tower( + task_tower + ) + tz_model_config_ob.task_towers.append(tz_task_tower) + tz_model_config.dbmtl.CopyFrom(tz_model_config_ob) + elif model_class == "SimpleMultiTask": + tz_model_config_ob = multi_task_rank_pb2.SimpleMultiTask() + for task_tower in easyrec_model_config.task_towers: + tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower) + tz_model_config_ob.task_towers.append(tz_task_tower) + tz_model_config.simple_multi_task.CopyFrom(tz_model_config_ob) + elif model_class == "MMoE": + tz_model_config_ob = multi_task_rank_pb2.MMoE() + expert_mlp = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.expert_dnn) + tz_model_config_ob.expert_mlp = expert_mlp + tz_model_config_ob.gate_mlp = expert_mlp + tz_model_config_ob.num_expert = easyrec_model_config.num_expert + for task_tower in easyrec_model_config.task_towers: + tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower) + tz_model_config_ob.task_towers.append(tz_task_tower) + tz_model_config.mmoe.CopyFrom(tz_model_config_ob) + elif model_class == "PLE": + tz_model_config_ob = multi_task_rank_pb2.PLE() + for extraction_network in easyrec_model_config.extraction_networks: + tz_extraction_network = ( + self._easyrec_extraction_network_2_tzrec_extraction_network( + extraction_network + ) + ) + tz_model_config.extraction_networks.append(tz_extraction_network) + for task_tower in easyrec_model_config.task_towers: + tz_task_tower = self._easyrec_task_tower_2_tzrec_task_tower(task_tower) + tz_model_config_ob.task_towers.append(tz_task_tower) + tz_model_config.ple.CopyFrom(tz_model_config_ob) + elif model_class == "DeepFM": + tz_model_config_ob = rank_model_pb2.DeepFM() + deep = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.dnn) + final = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.final_dnn) + tz_model_config_ob.deep = deep + tz_model_config_ob.final = final + if easyrec_model_config.HasField("wide_output_dim"): + tz_model_config_ob.wide_embedding_dim = ( + easyrec_model_config.wide_output_dim + ) + tz_model_config.deepfm.CopyFrom(tz_model_config_ob) + elif model_class == "MultiTower": + tz_model_config_ob = rank_model_pb2.MultiTower() + for tower in easyrec_model_config.towers: + tz_tower = self._easyrec_tower_2_tzrec_tower(tower) + tz_model_config_ob.towers.append(tz_tower) + final = self._easyrec_dnn_2_tzrec_mlp(easyrec_model_config.final_dnn) + tz_model_config_ob.final = final + tz_model_config.multi_tower.CopyFrom(tz_model_config_ob) + elif model_class == "DSSM": + tz_model_config_ob = match_model_pb2.DSSM() + user_tower = self._easyrec_dssm_tower_2_tzrec_tower( + easyrec_model_config.user_tower + ) + tz_model_config_ob.user_tower = user_tower + item_tower = self._easyrec_dssm_tower_2_tzrec_tower( + easyrec_model_config.item_tower + ) + tz_model_config_ob.item_tower = item_tower + tz_model_config_ob.output_dim = 32 + if easyrec_model_config.HasField("temperature"): + tz_model_config_ob.temperature = easyrec_model_config.temperature + tz_model_config.dssm.CopyFrom(tz_model_config_ob) + else: + logger.error( + f"{model_class} is not convert to tzrec model, please adaptation" + ) + return tz_model_config + + def _create_model_config(self, pipeline_config): + """Convert easyrec model config to tzrec model config.""" + tz_model_config = model_pb2.ModelConfig() + easyrec_model_config = self.easyrec_config.model_config + easyrec_feature_groups = easyrec_model_config.feature_groups + tz_feature_groups = self._convert_model_feature_group(easyrec_feature_groups) + tz_model_config.feature_groups.extend(tz_feature_groups) + tz_model_config = self._convert_model_config( + easyrec_model_config, tz_model_config + ) + pipeline_config.model_config.CopyFrom(tz_model_config) + return pipeline_config + + def build(self): + """Create tzrec model config order by easyrec config and fg file.""" + tzrec_config = tzrec_pipeline_pb2.EasyRecConfig() + tzrec_config = self._create_train_config(tzrec_config) + tzrec_config = self._create_eval_config(tzrec_config) + tzrec_config = self._create_data_config(tzrec_config) + tzrec_config = self._create_feature_config(tzrec_config) + tzrec_config = self._create_model_config(tzrec_config) + config_text = text_format.MessageToString(tzrec_config, as_utf8=True) + with open(self.output_tzrec_config_path, "w") as f: + f.write(config_text) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--easyrec_config_path", + type=str, + default=None, + help="easyrec model config path", + ) + parser.add_argument( + "--fg_json_path", type=str, default=None, help="easyrec use fg.json path" + ) + parser.add_argument( + "--output_tzrec_config_path", + type=str, + default=None, + help="output tzrec config path", + ) + args, extra_args = parser.parse_known_args() + fs = ConvertConfig( + args.easyrec_config_path, + args.fg_json_path, + args.output_tzrec_config_path, + ) + fs.build() diff --git a/tzrec/tools/convert_easyrec_config_to_tzrec_config_test_tmp.py b/tzrec/tools/convert_easyrec_config_to_tzrec_config_test_tmp.py new file mode 100644 index 0000000..c9c1c07 --- /dev/null +++ b/tzrec/tools/convert_easyrec_config_to_tzrec_config_test_tmp.py @@ -0,0 +1,736 @@ +# Copyright (c) 2024, Alibaba Group; +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# http://www.apache.org/licenses/LICENSE-2.0 +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +import shutil +import tempfile +import unittest + +from google.protobuf import text_format + +from tzrec.protos import pipeline_pb2 as tzrec_pipeline_pb2 +from tzrec.tools.convert_easyrec_config_to_tzrec_config import ConvertConfig + +FG_JSON = { + "features": [ + { + "feature_name": "user_id", + "feature_type": "id_feature", + "value_type": "String", + "expression": "user:user_id", + "default_value": "-1024", + "combiner": "mean", + "need_prefix": False, + "is_multi": False, + }, + { + "feature_name": "item_id", + "feature_type": "id_feature", + "value_type": "String", + "expression": "item:item_id", + "default_value": "-1024", + "combiner": "mean", + "need_prefix": False, + "is_multi": False, + }, + { + "feature_name": "user_blue_level", + "feature_type": "id_feature", + "value_type": "String", + "expression": "user:user_blue_level", + "default_value": "-1024", + "combiner": "mean", + "need_prefix": False, + "is_multi": False, + }, + { + "feature_name": "host_price_level", + "feature_type": "id_feature", + "value_type": "String", + "expression": "item:host_price_level", + "default_value": "-1024", + "combiner": "mean", + "need_prefix": False, + "is_multi": False, + }, + { + "feature_name": "user_video_sequence", + "feature_type": "id_feature", + "value_type": "String", + "expression": "user:user_video_sequence", + "default_value": "-1024", + "combiner": "mean", + "need_prefix": False, + "is_multi": False, + }, + { + "feature_name": "item__kv_user_blue_level_exposure_cnt_7d", + "feature_type": "lookup_feature", + "value_type": "Double", + "map": "item:item__kv_user_blue_level_exposure_cnt_7d", + "key": "user:user_blue_level", + "needDiscrete": False, + "needWeighting": False, + "needKey": False, + "default_value": "0", + "combiner": "mean", + "need_prefix": False, + }, + { + "feature_name": "item__kv_user_blue_level_click_focus_cnt_7d", + "feature_type": "id_feature", + "value_type": "Double", + "expression": "item:item__kv_user_blue_level_click_focus_cnt_7d", + "default_value": "", + "combiner": "mean", + "need_prefix": False, + }, + { + "feature_name": "item__kv_user_blue_level_click_video_div_exposure_cnt_30d", + "feature_type": "lookup_feature", + "value_type": "Double", + "map": "item:item__kv_user_blue_level_click_video_div_exposure_cnt_30d", + "key": "user:user_blue_level", + "needDiscrete": False, + "needWeighting": False, + "needKey": False, + "default_value": "0", + "combiner": "mean", + "need_prefix": False, + }, + { + "sequence_name": "click_100_seq", + "sequence_column": "click_100_seq", + "sequence_length": 100, + "sequence_delim": ";", + "attribute_delim": "#", + "sequence_table": "item", + "sequence_pk": "user:click_100_seq", + "features": [ + { + "feature_name": "item_id", + "feature_type": "id_feature", + "value_type": "String", + "expression": "item:item_id", + "default_value": "-1024", + "combiner": "mean", + "need_prefix": False, + "is_multi": False, + "group": "click_100_seq_feature", + }, + { + "feature_name": "ts", + "feature_type": "raw_feature", + "value_type": "Double", + "expression": "user:ts", + "default_value": "-1024", + "combiner": "mean", + "need_prefix": False, + "group": "click_100_seq_feature", + }, + ], + }, + ], + "reserves": ["is_click_cover", "is_click_video"], +} + +EASYREC_CONFIG = """train_config { + optimizer_config { + use_moving_average: false + adam_asyncw_optimizer { + weight_decay: 1e-6 + learning_rate { + constant_learning_rate { + learning_rate: 0.001 + } + } + } + } + sync_replicas: false + save_summary_steps: 1000 + log_step_count_steps: 100 + save_checkpoints_steps: 1000000 + keep_checkpoint_max: 1 +} +data_config { + batch_size: 4096 + label_fields: "is_click_cover" + label_fields: "is_click_video" + shuffle: false + num_epochs: 3 + input_type: OdpsRTPInput + separator: "" + selected_cols: "is_click_cover,is_click_video,features" + input_fields { + input_name: "is_click_cover" + input_type: INT32 + default_val: "0" + } + input_fields { + input_name: "is_click_video" + input_type: INT32 + default_val: "0" + } + input_fields { + input_name: "user_id" + input_type: STRING + default_val: "-1024" + } + input_fields { + input_name: "item_id" + input_type: STRING + default_val: "-1024" + } + input_fields { + input_name: "user_blue_level" + input_type: STRING + default_val: "-1024" + } + input_fields { + input_name: "host_price_level" + input_type: STRING + default_val: "-1024" + } + input_fields { + input_name: "user_video_sequence" + input_type: STRING + default_val: "-1024" + } + input_fields { + input_name: "item__kv_user_blue_level_exposure_cnt_7d" + input_type: DOUBLE + default_val: "0" + } + input_fields { + input_name: "item__kv_user_blue_level_click_focus_cnt_7d" + input_type: STRING + default_val: "" + } + input_fields { + input_name: "item__kv_user_blue_level_click_video_div_exposure_cnt_30d" + input_type: DOUBLE + default_val: "0" + } + input_fields { + input_name: "click_100_seq__item_id" + input_type: STRING + default_val: "" + } + input_fields { + input_name: "click_100_seq__ts" + input_type: STRING + default_val: "" + } + pai_worker_queue: true +} +feature_configs { + input_names: "user_id" + feature_type: IdFeature + embedding_dim: 4 + hash_bucket_size: 1000 + separator: "" + combiner: "mean" +} +feature_configs { + input_names: "item_id" + feature_type: IdFeature + embedding_dim: 24 + hash_bucket_size: 1500000 + separator: "" + combiner: "mean" +} +feature_configs { + input_names: "user_blue_level" + feature_type: IdFeature + embedding_dim: 4 + hash_bucket_size: 140 + separator: "" + combiner: "mean" +} +feature_configs { + input_names: "host_price_level" + feature_type: IdFeature + embedding_dim: 8 + hash_bucket_size: 180 + separator: "" + combiner: "mean" +} +feature_configs { + input_names: "user_video_sequence" + feature_type: SequenceFeature + embedding_dim: 24 + hash_bucket_size: 1500000 + separator: "," + combiner: "mean" + sub_feature_type: IdFeature +} + +feature_configs { + input_names: "item__kv_user_blue_level_exposure_cnt_7d" + feature_type: RawFeature + embedding_dim: 4 + boundaries: 1e-08 + boundaries: 47.00000001 + boundaries: 285.00000001 + boundaries: 672.00000001 + boundaries: 1186.00000001 + boundaries: 1853.00000001 + boundaries: 2716.00000001 + boundaries: 3861.00000001 + boundaries: 5459.00000001 + boundaries: 7817.00000001 + boundaries: 11722.0 + boundaries: 19513.0 + boundaries: 43334.0 + separator: "" +} +feature_configs { + feature_name: "item__kv_user_blue_level_click_focus_cnt_7d" + input_names: "item__kv_user_blue_level_click_focus_cnt_7d" + input_names: "user_blue_level" + feature_type: LookupFeature + embedding_dim: 4 + boundaries: 1e-08 + boundaries: 1.00000001 + boundaries: 2.00000001 + boundaries: 5.00000001 + boundaries: 8.00000001 + boundaries: 13.00000001 + boundaries: 19.00000001 + boundaries: 28.00000001 + boundaries: 42.00000001 + boundaries: 67.00000001 + boundaries: 123.00000001 + boundaries: 298.00000001 + separator: "" +} +feature_configs { + feature_name: "combo_user_blue_level_x_host_price_level" + input_names: "user_blue_level" + input_names: "host_price_level" + feature_type: ComboFeature + embedding_dim: 4 + hash_bucket_size: 140 +} +feature_configs { + input_names: "item__kv_user_blue_level_click_video_div_exposure_cnt_30d" + feature_type: RawFeature + separator: "" +} +feature_configs { + input_names: "click_100_seq__item_id" + feature_type: SequenceFeature + separator: ";" + combiner: "mean" + sub_feature_type: IdFeature + embedding_dim: 24 + hash_bucket_size: 1500000 +} +feature_configs { + input_names: "click_100_seq__ts" + feature_type: SequenceFeature + separator: ";" + combiner: "mean" + sub_feature_type: RawFeature + embedding_dim: 4 + boundaries: 1e-08 + boundaries: 3.00000001 + boundaries: 25.00000001 + boundaries: 70.00000001 +} + +model_config { + model_class: "DBMTL" + feature_groups { + group_name: "all" + feature_names: "user_id" + feature_names: "item_id" + feature_names: "user_blue_level" + feature_names: "host_price_level" + feature_names: "item__kv_user_blue_level_exposure_cnt_7d" + feature_names: "item__kv_user_blue_level_click_focus_cnt_7d" + feature_names: "combo_user_blue_level_x_host_price_level" + feature_names: "item__kv_user_blue_level_click_video_div_exposure_cnt_30d" + wide_deep: DEEP + sequence_features { + group_name: "seq_fea_1" + seq_att_map { + key: "item_id" + hist_seq: "user_video_sequence" + } + tf_summary: false + allow_key_search: false + } + sequence_features { + group_name: "seq_fea_2" + seq_att_map { + key: "item_id" + hist_seq: "click_100_seq__item_id" + aux_hist_seq: "click_100_seq__ts" + } + tf_summary: false + allow_key_search: false + } + } + dbmtl { + bottom_dnn { + hidden_units: 1024 + hidden_units: 512 + } + task_towers { + tower_name: "is_click_cover" + label_name: "is_click_cover" + metrics_set { + auc { + } + } + loss_type: CLASSIFICATION + dnn { + hidden_units: 256 + hidden_units: 128 + hidden_units: 64 + hidden_units: 32 + } + relation_dnn { + hidden_units: 32 + } + weight: 1.0 + } + task_towers { + tower_name: "is_click_video" + label_name: "is_click_video" + metrics_set { + auc { + } + } + loss_type: CLASSIFICATION + dnn { + hidden_units: 256 + hidden_units: 128 + hidden_units: 64 + hidden_units: 32 + } + relation_dnn { + hidden_units: 32 + } + weight: 2.0 + } + l2_regularization: 0 + } +} +export_config { + multi_placeholder: true +} +""" + +TRAIN_CONFIG = """train_config { + sparse_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + dense_optimizer { + adam_optimizer { + lr: 0.001 + } + constant_learning_rate { + } + } + num_epochs: 1 + use_tensorboard: false +} +""" + +DATA_CONFIG = """data_config { + batch_size: 4096 + dataset_type: OdpsDataset + fg_encoded: true + label_fields: "is_click_cover" + label_fields: "is_click_video" + num_workers: 8 + odps_data_quota_name: "" +} +""" + +FEATURE_CONFIG = """feature_configs { + id_feature { + feature_name: "user_id" + expression: "user:user_id" + embedding_dim: 4 + hash_bucket_size: 1000 + } +} +feature_configs { + id_feature { + feature_name: "item_id" + expression: "item:item_id" + embedding_dim: 24 + hash_bucket_size: 1500000 + } +} +feature_configs { + id_feature { + feature_name: "user_blue_level" + expression: "user:user_blue_level" + embedding_dim: 4 + hash_bucket_size: 140 + } +} +feature_configs { + id_feature { + feature_name: "host_price_level" + expression: "item:host_price_level" + embedding_dim: 8 + hash_bucket_size: 180 + } +} +feature_configs { + sequence_id_feature { + feature_name: "user_video_sequence" + expression: "user:user_video_sequence" + embedding_dim: 24 + hash_bucket_size: 1500000 + } +} +feature_configs { + lookup_feature { + feature_name: "item__kv_user_blue_level_exposure_cnt_7d" + map: "item:item__kv_user_blue_level_exposure_cnt_7d" + key: "user:user_blue_level" + embedding_dim: 4 + boundaries: 1e-08 + boundaries: 47.0 + boundaries: 285.0 + boundaries: 672.0 + boundaries: 1186.0 + boundaries: 1853.0 + boundaries: 2716.0 + boundaries: 3861.0 + boundaries: 5459.0 + boundaries: 7817.0 + boundaries: 11722.0 + boundaries: 19513.0 + boundaries: 43334.0 + } +} +feature_configs { + lookup_feature { + feature_name: "item__kv_user_blue_level_click_focus_cnt_7d" + map: "item:item__kv_user_blue_level_click_focus_cnt_7d" + key: "user:user_blue_level" + embedding_dim: 4 + boundaries: 1e-08 + boundaries: 1.0 + boundaries: 2.0 + boundaries: 5.0 + boundaries: 8.0 + boundaries: 13.0 + boundaries: 19.0 + boundaries: 28.0 + boundaries: 42.0 + boundaries: 67.0 + boundaries: 123.0 + boundaries: 298.0 + } +} +feature_configs { + combo_feature { + feature_name: "combo_user_blue_level_x_host_price_level" + expression: "user:user_blue_level" + expression: "item:host_price_level" + embedding_dim: 4 + hash_bucket_size: 140 + } +} +feature_configs { + lookup_feature { + feature_name: "item__kv_user_blue_level_click_video_div_exposure_cnt_30d" + map: "item:item__kv_user_blue_level_click_video_div_exposure_cnt_30d" + key: "user:user_blue_level" + embedding_dim: 0 + } +} +feature_configs { + sequence_feature { + sequence_name: "click_100_seq" + sequence_length: 100 + sequence_delim: ";" + features { + id_feature { + feature_name: "item_id" + expression: "item:item_id" + embedding_dim: 24 + hash_bucket_size: 1500000 + } + } + features { + raw_feature { + feature_name: "ts" + expression: "user:ts" + embedding_dim: 4 + boundaries: 1e-08 + boundaries: 3.0 + boundaries: 25.0 + boundaries: 70.0 + } + } + } +} +""" + +MODEL_CONFIG = """model_config { + feature_groups { + group_name: "all" + feature_names: "user_id" + feature_names: "item_id" + feature_names: "user_blue_level" + feature_names: "host_price_level" + feature_names: "item__kv_user_blue_level_exposure_cnt_7d" + feature_names: "item__kv_user_blue_level_click_focus_cnt_7d" + feature_names: "combo_user_blue_level_x_host_price_level" + feature_names: "item__kv_user_blue_level_click_video_div_exposure_cnt_30d" + group_type: DEEP + sequence_groups { + group_name: "seq_fea_1" + feature_names: "item_id" + feature_names: "user_video_sequence" + } + sequence_groups { + group_name: "seq_fea_2" + feature_names: "item_id" + feature_names: "click_100_seq__item_id" + feature_names: "click_100_seq__ts" + } + sequence_encoders { + din_encoder { + input: "seq_fea_1" + attn_mlp { + use_bn: true + } + } + } + sequence_encoders { + din_encoder { + input: "seq_fea_2" + attn_mlp { + use_bn: true + } + } + } + } + dbmtl { + bottom_mlp { + hidden_units: 1024 + hidden_units: 512 + use_bn: true + } + expert_mlp { + use_bn: true + } + num_expert: 0 + task_towers { + tower_name: "is_click_cover" + label_name: "is_click_cover" + metrics { + auc { + } + } + num_class: 1 + mlp { + hidden_units: 256 + hidden_units: 128 + hidden_units: 64 + hidden_units: 32 + use_bn: true + } + relation_mlp { + hidden_units: 32 + use_bn: true + } + } + task_towers { + tower_name: "is_click_video" + label_name: "is_click_video" + metrics { + auc { + } + } + num_class: 1 + mlp { + hidden_units: 256 + hidden_units: 128 + hidden_units: 64 + hidden_units: 32 + use_bn: true + } + relation_mlp { + hidden_units: 32 + use_bn: true + } + } + } +} +""" + + +class ConvertConfigTest(unittest.TestCase): + def setUp(self): + self.success = False + if not os.path.exists("./tmp"): + os.makedirs("./tmp") + self.test_dir = tempfile.mkdtemp(prefix="tzrec_convert_", dir="./tmp") + self.fg_path = os.path.join(self.test_dir, "fg.json") + self.easyrec_path = os.path.join(self.test_dir, "easyrec.config") + self.tzrec_path = os.path.join(self.test_dir, "tzrec.config") + with open(self.easyrec_path, "w", encoding="utf-8") as f: + f.write(EASYREC_CONFIG) + f = open(self.fg_path, "w", encoding="utf-8") + json.dump(FG_JSON, f) + f.close() + + def tearDown(self): + if os.path.exists(self.test_dir): + shutil.rmtree(self.test_dir) + + def test_create_train_config(self): + convert = ConvertConfig(self.easyrec_path, self.fg_path, self.tzrec_path) + config = tzrec_pipeline_pb2.EasyRecConfig() + config = convert._create_train_config(config) + config_text = text_format.MessageToString(config, as_utf8=True) + self.assertEqual(config_text, TRAIN_CONFIG) + + def test_create_data_config(self): + convert = ConvertConfig(self.easyrec_path, self.fg_path, self.tzrec_path) + config = tzrec_pipeline_pb2.EasyRecConfig() + config = convert._create_data_config(config) + config_text = text_format.MessageToString(config, as_utf8=True) + self.assertEqual(config_text, DATA_CONFIG) + + def test_create_feature_config(self): + convert = ConvertConfig(self.easyrec_path, self.fg_path, self.tzrec_path) + config = tzrec_pipeline_pb2.EasyRecConfig() + config = convert._create_feature_config(config) + config_text = text_format.MessageToString(config, as_utf8=True) + self.assertEqual(config_text, FEATURE_CONFIG) + + def test_create_model_config(self): + convert = ConvertConfig(self.easyrec_path, self.fg_path, self.tzrec_path) + config = tzrec_pipeline_pb2.EasyRecConfig() + config = convert._create_model_config(config) + config_text = text_format.MessageToString(config, as_utf8=True) + self.assertEqual(config_text, MODEL_CONFIG) + + +if __name__ == "__main__": + unittest.main()