diff --git a/tzrec/tools/tdm/cluster_tree.py b/tzrec/tools/tdm/cluster_tree.py index c689d0c..b0d7145 100644 --- a/tzrec/tools/tdm/cluster_tree.py +++ b/tzrec/tools/tdm/cluster_tree.py @@ -71,6 +71,12 @@ default=2, help="The branching factor of the nodes in the tree.", ) + parser.add_argument( + "--odps_data_quota_name", + type=str, + default="pay-as-you-go", + help="maxcompute storage api/tunnel data quota name.", + ) args, extra_args = parser.parse_known_args() cluster = TreeCluster( @@ -82,6 +88,7 @@ embedding_field=args.embedding_field, parallel=args.parallel, n_cluster=args.n_cluster, + odps_data_quota_name=args.odps_data_quota_name, ) root = cluster.train() logger.info("Tree cluster done. Start save nodes and edges table.") @@ -89,6 +96,7 @@ output_file=args.node_edge_output_file, root=root, child_num=args.n_cluster, + odps_data_quota_name=args.odps_data_quota_name, ) tree_search.save() tree_search.save_predict_edge() diff --git a/tzrec/tools/tdm/gen_tree/tree_cluster.py b/tzrec/tools/tdm/gen_tree/tree_cluster.py index cccabc6..83ceef5 100644 --- a/tzrec/tools/tdm/gen_tree/tree_cluster.py +++ b/tzrec/tools/tdm/gen_tree/tree_cluster.py @@ -14,7 +14,7 @@ import os import time from multiprocessing.connection import Connection -from typing import List, Optional +from typing import Any, List, Optional import numpy as np import numpy.typing as npt @@ -49,6 +49,7 @@ def __init__( embedding_field: str = "item_emb", parallel: int = 16, n_cluster: int = 2, + **kwargs: Any, ) -> None: self.item_input_path = item_input_path self.mini_batch = 1024 @@ -73,15 +74,19 @@ def __init__( self.embedding_field = embedding_field + self.dataset_kwargs = {} + if "odps_data_quota_name" in kwargs: + self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"] + def _read(self) -> None: t1 = time.time() ids = list() data = list() attrs = list() raw_attrs = list() - reader = create_reader(self.item_input_path, 256) + reader = create_reader(self.item_input_path, 256, **self.dataset_kwargs) for data_dict in reader.to_batches(): - ids += data_dict[self.item_id_field].to_pylist() + ids += data_dict[self.item_id_field].cast(pa.int64()).to_pylist() data += data_dict[self.embedding_field].to_pylist() tmp_attr = [] if self.attr_fields is not None: diff --git a/tzrec/tools/tdm/gen_tree/tree_generator.py b/tzrec/tools/tdm/gen_tree/tree_generator.py index 83eef19..c61be7c 100644 --- a/tzrec/tools/tdm/gen_tree/tree_generator.py +++ b/tzrec/tools/tdm/gen_tree/tree_generator.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import numpy as np import pyarrow as pa @@ -39,6 +39,7 @@ def __init__( raw_attr_fields: Optional[str] = None, tree_output_dir: Optional[str] = None, n_cluster: int = 2, + **kwargs: Any, ) -> None: self.item_input_path = item_input_path self.item_id_field = item_id_field @@ -52,6 +53,10 @@ def __init__( self.tree_output_dir = tree_output_dir self.n_cluster = n_cluster + self.dataset_kwargs = {} + if "odps_data_quota_name" in kwargs: + self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"] + def generate(self, save_tree: bool = False) -> TDMTreeClass: """Generate tree.""" item_fea = self._read() @@ -60,9 +65,11 @@ def generate(self, save_tree: bool = False) -> TDMTreeClass: def _read(self) -> Dict[str, List[Union[int, float, str]]]: item_fea = {"ids": [], "cates": [], "attrs": [], "raw_attrs": []} - reader = create_reader(self.item_input_path, 4096) + reader = create_reader(self.item_input_path, 4096, **self.dataset_kwargs) for data_dict in reader.to_batches(): - item_fea["ids"] += data_dict[self.item_id_field].to_pylist() + item_fea["ids"] += ( + data_dict[self.item_id_field].cast(pa.int64()).to_pylist() + ) item_fea["cates"] += ( data_dict[self.cate_id_field] .cast(pa.string()) diff --git a/tzrec/tools/tdm/gen_tree/tree_search_util.py b/tzrec/tools/tdm/gen_tree/tree_search_util.py index 5648af3..ce9984a 100644 --- a/tzrec/tools/tdm/gen_tree/tree_search_util.py +++ b/tzrec/tools/tdm/gen_tree/tree_search_util.py @@ -12,7 +12,7 @@ import os import pickle from collections import OrderedDict -from typing import Callable, Iterator, List, Optional, Tuple +from typing import Any, Callable, Iterator, List, Optional, Tuple import pyarrow as pa from anytree.importer.dictimporter import DictImporter @@ -62,6 +62,7 @@ def __init__( tree_path: Optional[str] = None, root: Optional[TDMTreeClass] = None, child_num: int = 2, + **kwargs: Any, ) -> None: self.child_num = child_num if root is not None: @@ -81,6 +82,10 @@ def __init__( self.output_file = output_file + self.dataset_kwargs = {} + if "odps_data_quota_name" in kwargs: + self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"] + self._get_nodes() def _load(self, path: str) -> None: @@ -111,7 +116,9 @@ def _get_nodes(self) -> None: def save(self) -> None: """Save tree info.""" if self.output_file.startswith("odps://"): - node_writer = create_writer(self.output_file + "node_table") + node_writer = create_writer( + self.output_file + "node_table", **self.dataset_kwargs + ) ids = [] weight = [] features = [] @@ -131,7 +138,9 @@ def save(self) -> None: node_table_dict["features"] = pa.array(features) node_writer.write(node_table_dict) - edge_writer = create_writer(self.output_file + "edge_table") + edge_writer = create_writer( + self.output_file + "edge_table", **self.dataset_kwargs + ) src_ids = [] dst_ids = [] weight = [] @@ -170,7 +179,9 @@ def save(self) -> None: def save_predict_edge(self) -> None: """Save edge info for prediction.""" if self.output_file.startswith("odps://"): - writer = create_writer(self.output_file + "predict_edge_table") + writer = create_writer( + self.output_file + "predict_edge_table", **self.dataset_kwargs + ) src_ids = [] dst_ids = [] weight = [] diff --git a/tzrec/tools/tdm/init_tree.py b/tzrec/tools/tdm/init_tree.py index d6d3954..902fade 100644 --- a/tzrec/tools/tdm/init_tree.py +++ b/tzrec/tools/tdm/init_tree.py @@ -65,6 +65,12 @@ default=2, help="The branching factor of the nodes in the tree.", ) + parser.add_argument( + "--odps_data_quota_name", + type=str, + default="pay-as-you-go", + help="maxcompute storage api/tunnel data quota name.", + ) args, extra_args = parser.parse_known_args() generator = TreeGenerator( @@ -75,6 +81,7 @@ raw_attr_fields=args.raw_attr_fields, tree_output_dir=args.tree_output_dir, n_cluster=args.n_cluster, + odps_data_quota_name=args.odps_data_quota_name, ) root = generator.generate() logger.info("Tree init done. Start save nodes and edges table.") @@ -82,6 +89,7 @@ output_file=args.node_edge_output_file, root=root, child_num=args.n_cluster, + odps_data_quota_name=args.odps_data_quota_name, ) tree_search.save() tree_search.save_predict_edge()