From dda0e4a5969f6ba47c3763aba9a55a1d06a36af7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Mon, 30 Sep 2024 17:00:00 +0800 Subject: [PATCH 1/9] add tdm serving tree & remove recall_num in gen tree --- docs/source/quick_start/local_tutorial_tdm.md | 8 ++-- tzrec/export.py | 6 +++ tzrec/main.py | 21 ++++++++-- tzrec/tests/train_eval_export_test.py | 10 ++++- tzrec/tests/utils.py | 22 +++++++---- tzrec/tools/tdm/cluster_tree.py | 11 +----- tzrec/tools/tdm/gen_tree/tree_search_util.py | 38 +++++++++---------- tzrec/tools/tdm/init_tree.py | 11 +----- tzrec/tools/tdm/retrieval.py | 18 +++++---- 9 files changed, 83 insertions(+), 62 deletions(-) diff --git a/docs/source/quick_start/local_tutorial_tdm.md b/docs/source/quick_start/local_tutorial_tdm.md index 5318f66..41589ad 100644 --- a/docs/source/quick_start/local_tutorial_tdm.md +++ b/docs/source/quick_start/local_tutorial_tdm.md @@ -52,7 +52,6 @@ python -m tzrec.tools.tdm.init_tree \ - --raw_attr_fields: (可选) item的数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致 - --tree_output_file: (可选)初始树的保存路径, 不输入不会保存 - --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持ODPS和本地txt两种 -- --recall_num: (可选,默认为200)召回数量, 会根据召回数量自动跳过前几层树, 增加召回的效率 - --n_cluster: (可选,默认为2)树的分叉数 #### 训练 @@ -80,11 +79,13 @@ torchrun --master_addr=localhost --master_port=32555 \ -m tzrec.export \ --pipeline_config_path experiments/tdm_taobao_local/pipeline.config \ --export_dir experiments/tdm_taobao_local/export + --asset_files data/init_tree/serving_tree ``` - --pipeline_config_path: 导出用的配置文件 - --checkpoint_path: 指定要导出的checkpoint, 默认评估model_dir下面最新的checkpoint - --export_dir: 导出到的模型目录 +- --asset_files: 需额拷贝到模型目录的文件。tdm需拷贝serving_tree树文件用于线上服务 #### 导出item embedding @@ -124,7 +125,6 @@ OMP_NUM_THREADS=4 python tzrec/tools/tdm/cluster_tree.py \ - --raw_attr_fields: (可选) item的数值型特征列名, 用逗号分开. 
注意和配置文件中tdm_sampler顺序一致 - --tree_output_file: (可选)树的保存路径, 不输入不会保存 - --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持ODPS和本地txt两种 -- --recall_num: (可选,默认为200)召回数量, 会根据召回数量自动跳过前几层树, 增加召回的效率 - --n_cluster: (可选,默认为2)树的分叉数 - --parllel: (可选,默认为16)聚类时CPU并行数 @@ -153,11 +153,13 @@ torchrun --master_addr=localhost --master_port=32555 \ -m tzrec.export \ --pipeline_config_path experiments/tdm_taobao_local_learnt/pipeline.config \ --export_dir experiments/tdm_taobao_local_learnt/export + --asset_files data/learnt_tree/serving_tree ``` - --pipeline_config_path: 导出用的配置文件 - --checkpoint_path: 指定要导出的checkpoint, 默认评估model_dir下面最新的checkpoint - --export_dir: 导出到的模型目录 +- --asset_files: 需额拷贝到模型目录的文件。tdm需拷贝serving_tree树文件用于线上服务 #### Recall评估 @@ -181,7 +183,7 @@ torchrun --master_addr=localhost --master_port=32555 \ - --predict_input_path: 预测输入数据的路径 - --predict_output_path: 预测输出数据的路径 - --gt_item_id_field: 文件中代表真实点击item_id的列名 -- --recall_num:(可选, 默认为200) 召回的数量, 应与建树时输入保持一致 +- --recall_num:(可选, 默认为200) 召回的数量 - --n_cluster:(可选, 默认为2) 数的分叉数量, 应与建树时输入保持一致 - --reserved_columns: 预测结果中要保留的输入列 diff --git a/tzrec/export.py b/tzrec/export.py index be18f0f..4da2005 100644 --- a/tzrec/export.py +++ b/tzrec/export.py @@ -34,6 +34,12 @@ default=None, help="directory where model should be exported to.", ) + parser.add_argument( + "--asset_files", + type=str, + default=None, + help="more files will be copy to export_dir.", + ) args, extra_args = parser.parse_known_args() export( diff --git a/tzrec/main.py b/tzrec/main.py index f927f21..ed9b872 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -13,6 +13,7 @@ import itertools import json import os +import shutil from collections import OrderedDict from queue import Queue from threading import Thread @@ -747,7 +748,10 @@ def _script_model( def export( - pipeline_config_path: str, export_dir: str, checkpoint_path: Optional[str] = None + pipeline_config_path: str, + export_dir: str, + checkpoint_path: Optional[str] = None, + asset_files: Optional[str] = None, ) -> None: """Export a EasyRec model. @@ -756,6 +760,7 @@ def export( export_dir (str): base directory where the model should be exported. checkpoint_path (str, optional): if specified, will use this model instead of model specified by model_dir in pipeline_config_path. + asset_files (str, optional): more files will be copy to export_dir. 
""" pipeline_config = config_util.load_pipeline_config(pipeline_config_path) ori_pipeline_config = copy.copy(pipeline_config) @@ -766,6 +771,10 @@ def export( if os.path.exists(export_dir): raise RuntimeError(f"directory {export_dir} already exist.") + assets = [] + if asset_files: + assets = asset_files.split(",") + data_config = pipeline_config.data_config # Build feature features = _create_features(list(pipeline_config.feature_configs), data_config) @@ -832,13 +841,16 @@ def export( for name, module in cpu_model.named_children(): if isinstance(module, MatchTower): tower = ScriptWrapper(TowerWrapper(module, name)) + tower_export_dir = os.path.join(export_dir, name.replace("_tower", "")) _script_model( ori_pipeline_config, tower, cpu_state_dict, dataloader, - os.path.join(export_dir, name.replace("_tower", "")), + tower_export_dir, ) + for asset in assets: + shutil.copy(asset, tower_export_dir) elif isinstance(cpu_model, TDM): for name, module in cpu_model.named_children(): if isinstance(module, EmbeddingGroup): @@ -857,7 +869,8 @@ def export( dataloader, export_dir, ) - + for asset in assets: + shutil.copy(asset, export_dir) else: _script_model( ori_pipeline_config, @@ -866,6 +879,8 @@ def export( dataloader, export_dir, ) + for asset in assets: + shutil.copy(asset, export_dir) def predict( diff --git a/tzrec/tests/train_eval_export_test.py b/tzrec/tests/train_eval_export_test.py index ed5ae5c..b551f17 100644 --- a/tzrec/tests/train_eval_export_test.py +++ b/tzrec/tests/train_eval_export_test.py @@ -537,7 +537,9 @@ def test_tdm_train_eval_export(self): ) if self.success: self.success = utils.test_export( - os.path.join(self.test_dir, "pipeline.config"), self.test_dir + os.path.join(self.test_dir, "pipeline.config"), + self.test_dir, + asset_files=os.path.join(self.test_dir, "init_tree/serving_tree"), ) if self.success: self.success = utils.test_predict( @@ -556,8 +558,9 @@ def test_tdm_train_eval_export(self): item_id="item_id", embedding_field="item_emb", ) + self.success = True if self.success: - with open(os.path.join(self.test_dir, "node_table.txt")) as f: + with open(os.path.join(self.test_dir, "init_tree/node_table.txt")) as f: for line_number, line in enumerate(f): if line_number == 1: root_id = int(line.split("\t")[0]) @@ -586,6 +589,9 @@ def test_tdm_train_eval_export(self): self.assertTrue( os.path.exists(os.path.join(self.test_dir, "export/scripted_model.pt")) ) + self.assertTrue( + os.path.exists(os.path.join(self.test_dir, "export/serving_tree")) + ) self.assertTrue(os.path.exists(os.path.join(self.test_dir, "retrieval_result"))) diff --git a/tzrec/tests/utils.py b/tzrec/tests/utils.py index c2ca33f..e23bd38 100644 --- a/tzrec/tests/utils.py +++ b/tzrec/tests/utils.py @@ -791,16 +791,19 @@ def load_config_for_test( f"--cate_id_field {cate_id} " f"--attr_fields {','.join(attr_fields)} " f"--raw_attr_fields {','.join(raw_attr_fields)} " - f"--node_edge_output_file {test_dir} " - f"--recall_num 1" + f"--node_edge_output_file {test_dir}/init_tree " ) p = misc_util.run_cmd(cmd_str, os.path.join(test_dir, "log_init_tree.txt")) p.wait(600) - sampler_config.item_input_path = os.path.join(test_dir, "node_table.txt") - sampler_config.edge_input_path = os.path.join(test_dir, "edge_table.txt") + sampler_config.item_input_path = os.path.join( + test_dir, "init_tree/node_table.txt" + ) + sampler_config.edge_input_path = os.path.join( + test_dir, "init_tree/edge_table.txt" + ) sampler_config.predict_edge_input_path = os.path.join( - test_dir, "predict_edge_table.txt" + test_dir, 
"init_tree/predict_edge_table.txt" ) else: @@ -874,7 +877,9 @@ def test_eval(pipeline_config_path: str, test_dir: str) -> bool: return True -def test_export(pipeline_config_path: str, test_dir: str) -> bool: +def test_export( + pipeline_config_path: str, test_dir: str, asset_files: str = "" +) -> bool: """Run export integration test.""" port = misc_util.get_free_port() log_dir = os.path.join(test_dir, "log_export") @@ -884,8 +889,10 @@ def test_export(pipeline_config_path: str, test_dir: str) -> bool: f"--nproc-per-node=2 --node_rank=0 --log_dir {log_dir} " "-r 3 -t 3 tzrec/export.py " f"--pipeline_config_path {pipeline_config_path} " - f"--export_dir {test_dir}/export" + f"--export_dir {test_dir}/export " ) + if asset_files: + cmd_str += f"--asset_files {asset_files}" p = misc_util.run_cmd(cmd_str, os.path.join(test_dir, "log_export.txt")) p.wait(600) @@ -1206,7 +1213,6 @@ def test_tdm_cluster_train_eval( f"--raw_attr_fields {','.join(raw_attr_fields)} " f"--node_edge_output_file {os.path.join(test_dir, 'learnt_tree')} " f"--parallel 1 " - f"--recall_num 1 " ) p = misc_util.run_cmd( cluster_cmd_str, os.path.join(test_dir, "log_tdm_cluster.txt") diff --git a/tzrec/tools/tdm/cluster_tree.py b/tzrec/tools/tdm/cluster_tree.py index dedfc5a..705f104 100644 --- a/tzrec/tools/tdm/cluster_tree.py +++ b/tzrec/tools/tdm/cluster_tree.py @@ -10,7 +10,6 @@ # limitations under the License. import argparse -import math from tzrec.tools.tdm.gen_tree.tree_cluster import TreeCluster from tzrec.tools.tdm.gen_tree.tree_search_util import TreeSearch @@ -66,12 +65,6 @@ default=16, help="The number of CPU cores for parallel processing.", ) - parser.add_argument( - "--recall_num", - type=int, - default=200, - help="Recall number per item when retrieval.", - ) parser.add_argument( "--n_cluster", type=int, @@ -102,6 +95,6 @@ child_num=args.n_cluster, ) tree_search.save() - first_recall_layer = int(math.ceil(math.log(2 * args.recall_num, args.n_cluster))) - tree_search.save_predict_edge(first_recall_layer) + tree_search.save_predict_edge() + tree_search.save_serving_tree() logger.info("Save nodes and edges table done.") diff --git a/tzrec/tools/tdm/gen_tree/tree_search_util.py b/tzrec/tools/tdm/gen_tree/tree_search_util.py index 8347eb9..a043dfd 100644 --- a/tzrec/tools/tdm/gen_tree/tree_search_util.py +++ b/tzrec/tools/tdm/gen_tree/tree_search_util.py @@ -167,25 +167,19 @@ def save(self) -> None: for i in range(self.max_level): f.write(f"{travel[0]}\t{travel[i+1]}\t{1.0}\n") - def save_predict_edge(self, first_recall_layer: int) -> None: + def save_predict_edge(self) -> None: """Save edge info for prediction.""" if self.output_file.startswith("odps://"): writer = create_writer(self.output_file + "predict_edge_table") src_ids = [] dst_ids = [] weight = [] - for i in range(first_recall_layer - 1, self.max_level): - if i == first_recall_layer - 1: - for node in self.level_code[i + 1]: - src_ids.append(self.root.item_id) - dst_ids.append(node.item_id) + for i in range(self.max_level): + for node in self.level_code[i]: + for child in node.children: + src_ids.append(node.item_id) + dst_ids.append(child.item_id) weight.append(1.0) - else: - for node in self.level_code[i]: - for child in node.children: - src_ids.append(node.item_id) - dst_ids.append(child.item_id) - weight.append(1.0) edge_table_dict = OrderedDict() edge_table_dict["src_id"] = pa.array(src_ids) edge_table_dict["dst_id"] = pa.array(dst_ids) @@ -196,11 +190,15 @@ def save_predict_edge(self, first_recall_layer: int) -> None: os.path.join(self.output_file, 
"predict_edge_table.txt"), "w" ) as f: f.write("src_id:int64\tdst_id:int64\tweight:float\n") - for i in range(first_recall_layer - 1, self.max_level): - if i == first_recall_layer - 1: - for node in self.level_code[i + 1]: - f.write(f"{self.root.item_id}\t{node.item_id}\t{1.0}\n") - else: - for node in self.level_code[i]: - for child in node.children: - f.write(f"{node.item_id}\t{child.item_id}\t{1.0}\n") + for i in range(self.max_level): + for node in self.level_code[i]: + for child in node.children: + f.write(f"{node.item_id}\t{child.item_id}\t{1.0}\n") + + def save_serving_tree(self) -> None: + """Save tree info for serving.""" + with open(os.path.join(self.output_file, "serving_tree"), "w") as f: + f.write(f"{self.max_level + 1} {self.child_num}\n") + for _, nodes in enumerate(self.level_code): + for node in nodes: + f.write(f"{node.tree_code} {node.item_id}\n") diff --git a/tzrec/tools/tdm/init_tree.py b/tzrec/tools/tdm/init_tree.py index 12d75d7..b2d8c86 100644 --- a/tzrec/tools/tdm/init_tree.py +++ b/tzrec/tools/tdm/init_tree.py @@ -10,7 +10,6 @@ # limitations under the License. import argparse -import math from tzrec.tools.tdm.gen_tree.tree_generator import TreeGenerator from tzrec.tools.tdm.gen_tree.tree_search_util import TreeSearch @@ -60,12 +59,6 @@ default=None, help="The nodes and edges table output file.", ) - parser.add_argument( - "--recall_num", - type=int, - default=200, - help="Recall number per item when retrieval.", - ) parser.add_argument( "--n_cluster", type=int, @@ -95,6 +88,6 @@ child_num=args.n_cluster, ) tree_search.save() - first_recall_layer = int(math.ceil(math.log(2 * args.recall_num, args.n_cluster))) - tree_search.save_predict_edge(first_recall_layer) + tree_search.save_predict_edge() + tree_search.save_serving_tree() logger.info("Save nodes and edges table done.") diff --git a/tzrec/tools/tdm/retrieval.py b/tzrec/tools/tdm/retrieval.py index 3b3f782..40f810a 100644 --- a/tzrec/tools/tdm/retrieval.py +++ b/tzrec/tools/tdm/retrieval.py @@ -198,7 +198,7 @@ def tdm_retrieval( sampler_config = pipeline_config.data_config.tdm_sampler item_id_field = sampler_config.item_id_field max_level = len(sampler_config.layer_num_sample) - first_recall_layer = int(math.ceil(math.log(2 * recall_num, n_cluster))) + first_recall_layer = int(math.ceil(math.log(2 * n_cluster * recall_num, n_cluster))) dataset = infer_dataloader.dataset # pyre-ignore [16] @@ -210,6 +210,7 @@ def tdm_retrieval( pos_sampler.init_cluster(num_client_per_rank=1) pos_sampler.launch_server() pos_sampler.init() + pos_sampler.init_sampler(n_cluster) i_step = 0 num_class = pipeline_config.model_config.num_class @@ -226,10 +227,14 @@ def tdm_retrieval( cur_batch_size = len(node_ids) expand_num = n_cluster**first_recall_layer - pos_sampler.init_sampler(expand_num) - - for layer in range(first_recall_layer, max_level): + for layer in range(1, max_level): sampled_result_dict = pos_sampler.get(node_ids) + + # skip layers before first_recall_layer + if layer < first_recall_layer: + node_ids = sampled_result_dict[item_id_field] + continue + updated_inputs = update_data( reserve_batch_record, sampled_result_dict, expand_num ) @@ -267,16 +272,13 @@ def tdm_retrieval( _, topk_indices_in_group = torch.topk(probs, k, dim=1) topk_indices = ( topk_indices_in_group - + torch.arange(cur_batch_size) - .unsqueeze(1) - .to(topk_indices_in_group.device) + + torch.arange(cur_batch_size, device=device).unsqueeze(1) * expand_num ) topk_indices = topk_indices.reshape(-1).cpu().numpy() node_ids = 
updated_inputs[item_id_field].take(topk_indices) if layer == first_recall_layer: - pos_sampler.init_sampler(n_cluster) expand_num = n_cluster * k output_dict = OrderedDict() From 8a95f19fbfac84c50544cf35ea0b0c2320a09518 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Mon, 30 Sep 2024 17:10:01 +0800 Subject: [PATCH 2/9] add asset_files to export entry --- tzrec/export.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tzrec/export.py b/tzrec/export.py index 4da2005..8a5cf0b 100644 --- a/tzrec/export.py +++ b/tzrec/export.py @@ -46,4 +46,5 @@ args.pipeline_config_path, export_dir=args.export_dir, checkpoint_path=args.checkpoint_path, + asset_files=args.asset_files, ) From 8ff0836bf0b735da44de98952894ae0cd7f54aa1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Mon, 30 Sep 2024 17:13:53 +0800 Subject: [PATCH 3/9] remove success=True in test --- tzrec/tests/train_eval_export_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tzrec/tests/train_eval_export_test.py b/tzrec/tests/train_eval_export_test.py index b551f17..639f28a 100644 --- a/tzrec/tests/train_eval_export_test.py +++ b/tzrec/tests/train_eval_export_test.py @@ -558,7 +558,6 @@ def test_tdm_train_eval_export(self): item_id="item_id", embedding_field="item_emb", ) - self.success = True if self.success: with open(os.path.join(self.test_dir, "init_tree/node_table.txt")) as f: for line_number, line in enumerate(f): From 805388257f456c13e8189666c4917f778b86b5cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Mon, 30 Sep 2024 17:32:15 +0800 Subject: [PATCH 4/9] fix tree search test --- tzrec/tools/tdm/gen_tree/tree_search_util_test.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tzrec/tools/tdm/gen_tree/tree_search_util_test.py b/tzrec/tools/tdm/gen_tree/tree_search_util_test.py index d22eda7..1204e98 100644 --- a/tzrec/tools/tdm/gen_tree/tree_search_util_test.py +++ b/tzrec/tools/tdm/gen_tree/tree_search_util_test.py @@ -51,11 +51,13 @@ def test_cluster(self) -> None: root = cluster.train(save_tree=False) search = TreeSearch(output_file=self.test_dir, root=root, child_num=2) search.save() - search.save_predict_edge(3) + search.save_predict_edge() + search.save_serving_tree() node_table = [] edge_table = [] predict_edge_table = [] + serving_tree = [] with open(os.path.join(self.test_dir, "node_table.txt")) as f: for line in f: node_table.append(line) @@ -63,14 +65,17 @@ def test_cluster(self) -> None: with open(os.path.join(self.test_dir, "edge_table.txt")) as f: for line in f: edge_table.append(line) - with open(os.path.join(self.test_dir, "predict_edge_table.txt")) as f: for line in f: predict_edge_table.append(line) + with open(os.path.join(self.test_dir, "serving_tree")) as f: + for line in f: + serving_tree.append(line) self.assertEqual(len(node_table), 14) self.assertEqual(len(edge_table), 19) - self.assertEqual(len(predict_edge_table), 7) + self.assertEqual(len(predict_edge_table), 13) + self.assertEqual(len(serving_tree), 14) if __name__ == "__main__": From 3146bf051597d9373d024fa0099160b57d040d2b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Mon, 30 Sep 2024 17:39:32 +0800 Subject: [PATCH 5/9] add failed exit 1 to test runner --- tzrec/tests/run.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tzrec/tests/run.py b/tzrec/tests/run.py index bc70a19..cef4844 100644 --- a/tzrec/tests/run.py +++ b/tzrec/tests/run.py @@ -12,6 +12,7 @@ import argparse import os +import sys import unittest @@ 
-48,4 +49,7 @@ def _gather_test_cases(args): runner = unittest.TextTestRunner() test_suite = _gather_test_cases(args) if not args.list_tests: - runner.run(test_suite) + result = runner.run(test_suite) + failed, errored = len(result.failures), len(result.errors) + if failed > 0 or errored > 0: + sys.exit(1) From ec8ef1d2b87beccd0c1bf465117b93ea11135167 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Fri, 4 Oct 2024 11:11:48 +0800 Subject: [PATCH 6/9] make serving tree use tree_output_dir as directory, node_edge_output_file may be odps table --- docs/source/quick_start/local_tutorial_tdm.md | 12 ++++++++---- tzrec/tests/utils.py | 2 ++ tzrec/tools/tdm/cluster_tree.py | 13 +++++-------- tzrec/tools/tdm/gen_tree/tree_cluster.py | 18 +++++++++--------- tzrec/tools/tdm/gen_tree/tree_generator.py | 6 +++--- tzrec/tools/tdm/gen_tree/tree_search_util.py | 6 ++++-- .../tdm/gen_tree/tree_search_util_test.py | 2 +- tzrec/tools/tdm/init_tree.py | 15 ++++++--------- 8 files changed, 38 insertions(+), 36 deletions(-) diff --git a/docs/source/quick_start/local_tutorial_tdm.md b/docs/source/quick_start/local_tutorial_tdm.md index 41589ad..6b71bb8 100644 --- a/docs/source/quick_start/local_tutorial_tdm.md +++ b/docs/source/quick_start/local_tutorial_tdm.md @@ -43,6 +43,7 @@ python -m tzrec.tools.tdm.init_tree \ --cate_id_field cate_id \ --attr_fields cate_id,campaign_id,customer,brand,price \ --node_edge_output_file data/init_tree +--tree_output_dir data/init_tree ``` - --item_input_path: 建树用的item特征文件 @@ -50,8 +51,10 @@ python -m tzrec.tools.tdm.init_tree \ - --cate_id_field: 代表item的类别的列名 - --attr_fields: (可选) 除了item_id外的item非数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致 - --raw_attr_fields: (可选) item的数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致 -- --tree_output_file: (可选)初始树的保存路径, 不输入不会保存 -- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持ODPS和本地txt两种 +- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持`ODPS GL表`和`本地txt GL`两种 + - ODPS GL表:设置形如`odps://{project}/tables/{tb_prefix}`,将会产出用于TDM训练负采样的GL Node表`odps://{project}/tables/{tb_prefix}_node_table`、GL Edge表`odps://{project}/tables/{tb_prefix}_edge_table`、用于离线检索的GL Edge表`odps://{project}/tables/{tb_prefix}_predict_edge_table` + - 本地txt GL表:设置的为目录, 将在目录下产出用于TDM训练负采样的GL Node表`node_table.txt`,GL Edge表`edge_table.txt`、用于离线检索的GL Edge表`predict_edge_table.txt` +- --tree_output_dir: (可选) 树的保存目录, 将会在目录下存储`serving_tree`文件用于线上服务 - --n_cluster: (可选,默认为2)树的分叉数 #### 训练 @@ -115,6 +118,7 @@ OMP_NUM_THREADS=4 python tzrec/tools/tdm/cluster_tree.py \ --embedding_field item_emb \ --attr_fields cate_id,campaign_id,customer,brand,price \ --node_edge_output_file data/learnt_tree \ + --tree_output_dir data/learnt_tree \ --parallel 16 ``` @@ -123,8 +127,8 @@ OMP_NUM_THREADS=4 python tzrec/tools/tdm/cluster_tree.py \ - --embedding_field: 代表item embedding的列名 - --attr_fields: (可选) 除了item_id外的item非数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致 - --raw_attr_fields: (可选) item的数值型特征列名, 用逗号分开. 
注意和配置文件中tdm_sampler顺序一致 -- --tree_output_file: (可选)树的保存路径, 不输入不会保存 -- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持ODPS和本地txt两种 +- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持`ODPS GL表`和`本地txt GL`两种,同初始树 +- --tree_output_dir: (可选) 树的保存目录, 将会在目录下存储`serving_tree`文件用于线上服务 - --n_cluster: (可选,默认为2)树的分叉数 - --parllel: (可选,默认为16)聚类时CPU并行数 diff --git a/tzrec/tests/utils.py b/tzrec/tests/utils.py index e23bd38..a70b12b 100644 --- a/tzrec/tests/utils.py +++ b/tzrec/tests/utils.py @@ -792,6 +792,7 @@ def load_config_for_test( f"--attr_fields {','.join(attr_fields)} " f"--raw_attr_fields {','.join(raw_attr_fields)} " f"--node_edge_output_file {test_dir}/init_tree " + f"--tree_output_dir {test_dir}/init_tree " ) p = misc_util.run_cmd(cmd_str, os.path.join(test_dir, "log_init_tree.txt")) p.wait(600) @@ -1212,6 +1213,7 @@ def test_tdm_cluster_train_eval( f"--attr_fields {','.join(attr_fields)} " f"--raw_attr_fields {','.join(raw_attr_fields)} " f"--node_edge_output_file {os.path.join(test_dir, 'learnt_tree')} " + f"--tree_output_dir {os.path.join(test_dir, 'learnt_tree')} " f"--parallel 1 " ) p = misc_util.run_cmd( diff --git a/tzrec/tools/tdm/cluster_tree.py b/tzrec/tools/tdm/cluster_tree.py index 705f104..c689d0c 100644 --- a/tzrec/tools/tdm/cluster_tree.py +++ b/tzrec/tools/tdm/cluster_tree.py @@ -48,7 +48,7 @@ help="The column names representing the raw features of item in the file.", ) parser.add_argument( - "--tree_output_file", + "--tree_output_dir", type=str, default=None, help="The tree output file.", @@ -78,16 +78,12 @@ item_id_field=args.item_id_field, attr_fields=args.attr_fields, raw_attr_fields=args.raw_attr_fields, - output_file=args.tree_output_file, + output_dir=args.tree_output_dir, embedding_field=args.embedding_field, parallel=args.parallel, n_cluster=args.n_cluster, ) - if args.tree_output_file: - save_tree = True - else: - save_tree = False - root = cluster.train(save_tree) + root = cluster.train() logger.info("Tree cluster done. Start save nodes and edges table.") tree_search = TreeSearch( output_file=args.node_edge_output_file, @@ -96,5 +92,6 @@ ) tree_search.save() tree_search.save_predict_edge() - tree_search.save_serving_tree() + if args.tree_output_dir: + tree_search.save_serving_tree(args.tree_output_dir) logger.info("Save nodes and edges table done.") diff --git a/tzrec/tools/tdm/gen_tree/tree_cluster.py b/tzrec/tools/tdm/gen_tree/tree_cluster.py index e43fe17..cccabc6 100644 --- a/tzrec/tools/tdm/gen_tree/tree_cluster.py +++ b/tzrec/tools/tdm/gen_tree/tree_cluster.py @@ -31,12 +31,12 @@ class TreeCluster: """Cluster based on emb vec. Args: - item_input_path(str): The file path where the item information is stored. - item_id_field(str): The column name representing item_id in the file. - attr_fields(List[str]): The column names representing the features in the file. - output_file(str): The output file. - parallel(int): The number of CPU cores for parallel processing. - n_cluster(int): The branching factor of the nodes in the tree. + item_input_path (str): The file path where the item information is stored. + item_id_field (str): The column name representing item_id in the file. + attr_fields (List[str]): The column names representing the features in the file. + output_dir (str): The output file. + parallel (int): The number of CPU cores for parallel processing. + n_cluster (int): The branching factor of the nodes in the tree. 
""" def __init__( @@ -45,7 +45,7 @@ def __init__( item_id_field: str, attr_fields: Optional[str] = None, raw_attr_fields: Optional[str] = None, - output_file: Optional[str] = None, + output_dir: Optional[str] = None, embedding_field: str = "item_emb", parallel: int = 16, n_cluster: int = 2, @@ -60,7 +60,7 @@ def __init__( self.queue = None self.timeout = 5 self.codes = None - self.output_file = output_file + self.output_dir = output_dir self.n_clusters = n_cluster self.item_id_field = item_id_field @@ -140,7 +140,7 @@ def train(self, save_tree: bool = False) -> TDMTreeClass: p.join() assert queue.empty() - builder = tree_builder.TreeBuilder(self.output_file, self.n_clusters) + builder = tree_builder.TreeBuilder(self.output_dir, self.n_clusters) root = builder.build( self.ids, self.codes, self.attrs, self.raw_attrs, self.data, save_tree ) diff --git a/tzrec/tools/tdm/gen_tree/tree_generator.py b/tzrec/tools/tdm/gen_tree/tree_generator.py index a4ec376..83eef19 100644 --- a/tzrec/tools/tdm/gen_tree/tree_generator.py +++ b/tzrec/tools/tdm/gen_tree/tree_generator.py @@ -37,7 +37,7 @@ def __init__( cate_id_field: str, attr_fields: Optional[str] = None, raw_attr_fields: Optional[str] = None, - tree_output_file: Optional[str] = None, + tree_output_dir: Optional[str] = None, n_cluster: int = 2, ) -> None: self.item_input_path = item_input_path @@ -49,7 +49,7 @@ def __init__( self.attr_fields = [x.strip() for x in attr_fields.split(",")] if raw_attr_fields: self.raw_attr_fields = [x.strip() for x in raw_attr_fields.split(",")] - self.tree_output_file = tree_output_file + self.tree_output_dir = tree_output_dir self.n_cluster = n_cluster def generate(self, save_tree: bool = False) -> TDMTreeClass: @@ -141,6 +141,6 @@ def gen_code(start: int, end: int, code: int, items: List[Item]) -> None: ) data = np.array([[] for i in range(len(ids))]) - builder = TreeBuilder(self.tree_output_file, self.n_cluster) + builder = TreeBuilder(self.tree_output_dir, self.n_cluster) root = builder.build(ids, codes, attrs, raw_attrs, data, save_tree) return root diff --git a/tzrec/tools/tdm/gen_tree/tree_search_util.py b/tzrec/tools/tdm/gen_tree/tree_search_util.py index a043dfd..5648af3 100644 --- a/tzrec/tools/tdm/gen_tree/tree_search_util.py +++ b/tzrec/tools/tdm/gen_tree/tree_search_util.py @@ -195,9 +195,11 @@ def save_predict_edge(self) -> None: for child in node.children: f.write(f"{node.item_id}\t{child.item_id}\t{1.0}\n") - def save_serving_tree(self) -> None: + def save_serving_tree(self, tree_output_dir: str) -> None: """Save tree info for serving.""" - with open(os.path.join(self.output_file, "serving_tree"), "w") as f: + if not os.path.exists(tree_output_dir): + os.makedirs(tree_output_dir) + with open(os.path.join(tree_output_dir, "serving_tree"), "w") as f: f.write(f"{self.max_level + 1} {self.child_num}\n") for _, nodes in enumerate(self.level_code): for node in nodes: diff --git a/tzrec/tools/tdm/gen_tree/tree_search_util_test.py b/tzrec/tools/tdm/gen_tree/tree_search_util_test.py index 1204e98..b5639f7 100644 --- a/tzrec/tools/tdm/gen_tree/tree_search_util_test.py +++ b/tzrec/tools/tdm/gen_tree/tree_search_util_test.py @@ -52,7 +52,7 @@ def test_cluster(self) -> None: search = TreeSearch(output_file=self.test_dir, root=root, child_num=2) search.save() search.save_predict_edge() - search.save_serving_tree() + search.save_serving_tree(self.test_dir) node_table = [] edge_table = [] diff --git a/tzrec/tools/tdm/init_tree.py b/tzrec/tools/tdm/init_tree.py index b2d8c86..d6d3954 100644 --- 
a/tzrec/tools/tdm/init_tree.py +++ b/tzrec/tools/tdm/init_tree.py @@ -48,10 +48,10 @@ help="The column names representing the raw features of item in the file.", ) parser.add_argument( - "--tree_output_file", + "--tree_output_dir", type=str, default=None, - help="The tree output file.", + help="The tree output directory.", ) parser.add_argument( "--node_edge_output_file", @@ -73,14 +73,10 @@ cate_id_field=args.cate_id_field, attr_fields=args.attr_fields, raw_attr_fields=args.raw_attr_fields, - tree_output_file=args.tree_output_file, + tree_output_dir=args.tree_output_dir, n_cluster=args.n_cluster, ) - if args.tree_output_file: - save_tree = True - else: - save_tree = False - root = generator.generate(save_tree) + root = generator.generate() logger.info("Tree init done. Start save nodes and edges table.") tree_search = TreeSearch( output_file=args.node_edge_output_file, @@ -89,5 +85,6 @@ ) tree_search.save() tree_search.save_predict_edge() - tree_search.save_serving_tree() + if args.tree_output_dir: + tree_search.save_serving_tree(args.tree_output_dir) logger.info("Save nodes and edges table done.") From 6172e5c3f8d2859e6c5fd0c03b9807251906799a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Sat, 5 Oct 2024 10:07:59 +0800 Subject: [PATCH 7/9] fix tests --- tzrec/tools/tdm/gen_tree/tree_cluster_test.py | 2 +- tzrec/tools/tdm/gen_tree/tree_search_util_test.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tzrec/tools/tdm/gen_tree/tree_cluster_test.py b/tzrec/tools/tdm/gen_tree/tree_cluster_test.py index 4383845..57e5751 100644 --- a/tzrec/tools/tdm/gen_tree/tree_cluster_test.py +++ b/tzrec/tools/tdm/gen_tree/tree_cluster_test.py @@ -39,7 +39,7 @@ def test_cluster(self) -> None: embedding_field="item_emb", attr_fields="cate_id,str_a", raw_attr_fields="raw_1", - output_file=None, + output_dir=None, parallel=1, n_cluster=2, ) diff --git a/tzrec/tools/tdm/gen_tree/tree_search_util_test.py b/tzrec/tools/tdm/gen_tree/tree_search_util_test.py index b5639f7..ebe276b 100644 --- a/tzrec/tools/tdm/gen_tree/tree_search_util_test.py +++ b/tzrec/tools/tdm/gen_tree/tree_search_util_test.py @@ -36,13 +36,13 @@ def setUp(self) -> None: def tearDown(self) -> None: shutil.rmtree(self.test_dir) - def test_cluster(self) -> None: + def test_tree_search(self) -> None: cluster = TreeCluster( item_input_path=self.tmp_file.name, item_id_field="item_id", attr_fields="int_a,str_c", raw_attr_fields="float_b", - output_file=None, + output_dir=None, embedding_field="item_emb", parallel=1, n_cluster=2, From 046b3c5ebcd66bb444a58a14bc62fe5f34a5b58b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 8 Oct 2024 10:19:09 +0800 Subject: [PATCH 8/9] fix typo --- tzrec/export.py | 2 +- tzrec/main.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tzrec/export.py b/tzrec/export.py index 8a5cf0b..8b3e535 100644 --- a/tzrec/export.py +++ b/tzrec/export.py @@ -38,7 +38,7 @@ "--asset_files", type=str, default=None, - help="more files will be copy to export_dir.", + help="more files will be copied to export_dir.", ) args, extra_args = parser.parse_known_args() diff --git a/tzrec/main.py b/tzrec/main.py index ed9b872..85e3d81 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -760,7 +760,7 @@ def export( export_dir (str): base directory where the model should be exported. checkpoint_path (str, optional): if specified, will use this model instead of model specified by model_dir in pipeline_config_path. 
- asset_files (str, optional): more files will be copy to export_dir. + asset_files (str, optional): more files will be copied to export_dir. """ pipeline_config = config_util.load_pipeline_config(pipeline_config_path) ori_pipeline_config = copy.copy(pipeline_config) From 0e15112ea467c2a66d5c045fbf3d172e4568f2ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Tue, 8 Oct 2024 10:20:25 +0800 Subject: [PATCH 9/9] remove gl doc --- docs/source/quick_start/local_tutorial_tdm.md | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/docs/source/quick_start/local_tutorial_tdm.md b/docs/source/quick_start/local_tutorial_tdm.md index 6b71bb8..037bf5d 100644 --- a/docs/source/quick_start/local_tutorial_tdm.md +++ b/docs/source/quick_start/local_tutorial_tdm.md @@ -51,9 +51,9 @@ python -m tzrec.tools.tdm.init_tree \ - --cate_id_field: 代表item的类别的列名 - --attr_fields: (可选) 除了item_id外的item非数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致 - --raw_attr_fields: (可选) item的数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致 -- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持`ODPS GL表`和`本地txt GL`两种 - - ODPS GL表:设置形如`odps://{project}/tables/{tb_prefix}`,将会产出用于TDM训练负采样的GL Node表`odps://{project}/tables/{tb_prefix}_node_table`、GL Edge表`odps://{project}/tables/{tb_prefix}_edge_table`、用于离线检索的GL Edge表`odps://{project}/tables/{tb_prefix}_predict_edge_table` - - 本地txt GL表:设置的为目录, 将在目录下产出用于TDM训练负采样的GL Node表`node_table.txt`,GL Edge表`edge_table.txt`、用于离线检索的GL Edge表`predict_edge_table.txt` +- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持`ODPS表`和`本地txt`两种 + - ODPS表:设置形如`odps://{project}/tables/{tb_prefix}`,将会产出用于TDM训练负采样的GL Node表`odps://{project}/tables/{tb_prefix}_node_table`、GL Edge表`odps://{project}/tables/{tb_prefix}_edge_table`、用于离线检索的GL Edge表`odps://{project}/tables/{tb_prefix}_predict_edge_table` + - 本地txt:设置的为目录, 将在目录下产出用于TDM训练负采样的GL Node表`node_table.txt`,GL Edge表`edge_table.txt`、用于离线检索的GL Edge表`predict_edge_table.txt` - --tree_output_dir: (可选) 树的保存目录, 将会在目录下存储`serving_tree`文件用于线上服务 - --n_cluster: (可选,默认为2)树的分叉数 @@ -127,7 +127,7 @@ OMP_NUM_THREADS=4 python tzrec/tools/tdm/cluster_tree.py \ - --embedding_field: 代表item embedding的列名 - --attr_fields: (可选) 除了item_id外的item非数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致 - --raw_attr_fields: (可选) item的数值型特征列名, 用逗号分开. 注意和配置文件中tdm_sampler顺序一致 -- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持`ODPS GL表`和`本地txt GL`两种,同初始树 +- --node_edge_output_file: 根据树生成的node和edge表的保存路径, 支持`ODPS表`和`本地txt`两种,同初始树 - --tree_output_dir: (可选) 树的保存目录, 将会在目录下存储`serving_tree`文件用于线上服务 - --n_cluster: (可选,默认为2)树的分叉数 - --parllel: (可选,默认为16)聚类时CPU并行数
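For reference, the `serving_tree` asset added in this series is a plain-text file: `TreeSearch.save_serving_tree` writes a header line `"{max_level + 1} {child_num}"` followed by one `"{tree_code} {item_id}"` pair per tree node. The sketch below is not part of the patch; it only illustrates, under that format assumption, how such a file could be read back for inspection (the path follows the tutorial's example layout).

```python
# Illustrative only -- not part of this patch. Assumes the `serving_tree` format
# written by TreeSearch.save_serving_tree above: a header "<level_count> <child_num>"
# followed by one "<tree_code> <item_id>" pair per tree node.
from typing import Dict, Tuple


def load_serving_tree(path: str) -> Tuple[int, int, Dict[int, int]]:
    """Return (level_count, child_num, {tree_code: item_id}) parsed from the file."""
    with open(path) as f:
        level_count, child_num = (int(x) for x in f.readline().split())
        code_to_item: Dict[int, int] = {}
        for line in f:
            code, item_id = line.split()
            code_to_item[int(code)] = int(item_id)
    return level_count, child_num, code_to_item


if __name__ == "__main__":
    # Adjust the path to your own tree_output_dir / export directory.
    levels, branching, nodes = load_serving_tree("data/init_tree/serving_tree")
    print(f"{levels} levels, branching factor {branching}, {len(nodes)} nodes")
```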
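As a worked check of the retrieval change in patch 1/9 (illustrative default values, not taken from the patch): with `n_cluster=2` and `recall_num=200`, the updated entry-layer formula in `tzrec/tools/tdm/retrieval.py` resolves as follows.

```python
# Worked example with assumed defaults n_cluster=2, recall_num=200; mirrors the
# first_recall_layer computation introduced in tzrec/tools/tdm/retrieval.py.
import math

n_cluster, recall_num = 2, 200
first_recall_layer = int(math.ceil(math.log(2 * n_cluster * recall_num, n_cluster)))
print(first_recall_layer)  # ceil(log2(800)) = 10, so beam search starts expanding at tree level 10
```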