From 805388257f456c13e8189666c4917f778b86b5cd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Mon, 30 Sep 2024 17:32:15 +0800 Subject: [PATCH] fix tree search test --- tzrec/tools/tdm/gen_tree/tree_search_util_test.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/tzrec/tools/tdm/gen_tree/tree_search_util_test.py b/tzrec/tools/tdm/gen_tree/tree_search_util_test.py index d22eda7..1204e98 100644 --- a/tzrec/tools/tdm/gen_tree/tree_search_util_test.py +++ b/tzrec/tools/tdm/gen_tree/tree_search_util_test.py @@ -51,11 +51,13 @@ def test_cluster(self) -> None: root = cluster.train(save_tree=False) search = TreeSearch(output_file=self.test_dir, root=root, child_num=2) search.save() - search.save_predict_edge(3) + search.save_predict_edge() + search.save_serving_tree() node_table = [] edge_table = [] predict_edge_table = [] + serving_tree = [] with open(os.path.join(self.test_dir, "node_table.txt")) as f: for line in f: node_table.append(line) @@ -63,14 +65,17 @@ def test_cluster(self) -> None: with open(os.path.join(self.test_dir, "edge_table.txt")) as f: for line in f: edge_table.append(line) - with open(os.path.join(self.test_dir, "predict_edge_table.txt")) as f: for line in f: predict_edge_table.append(line) + with open(os.path.join(self.test_dir, "serving_tree")) as f: + for line in f: + serving_tree.append(line) self.assertEqual(len(node_table), 14) self.assertEqual(len(edge_table), 19) - self.assertEqual(len(predict_edge_table), 7) + self.assertEqual(len(predict_edge_table), 13) + self.assertEqual(len(serving_tree), 14) if __name__ == "__main__":