diff --git a/tzrec/datasets/csv_dataset.py b/tzrec/datasets/csv_dataset.py index 3be3d92..6cb035e 100644 --- a/tzrec/datasets/csv_dataset.py +++ b/tzrec/datasets/csv_dataset.py @@ -156,3 +156,4 @@ def close(self) -> None: """Close and commit data.""" if self._writer is not None: self._writer.close() + super().close() diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index 47b192b..dd151db 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -29,6 +29,7 @@ from tzrec.features.feature import BaseFeature from tzrec.protos import data_pb2 from tzrec.utils.load_class import get_register_class_meta +from tzrec.utils.logging_util import logger _DATASET_CLASS_MAP = {} _READER_CLASS_MAP = {} @@ -429,7 +430,12 @@ def write(self, output_dict: OrderedDict[str, pa.Array]) -> None: def close(self) -> None: """Close and commit data.""" - pass + self._lazy_inited = False + + def __del__(self) -> None: + if self._lazy_inited: + # pyre-ignore [16] + logger.warning(f"You should close {self.__class__.__name__} explicitly.") def create_reader( diff --git a/tzrec/datasets/odps_dataset.py b/tzrec/datasets/odps_dataset.py index 0455431..079ee43 100644 --- a/tzrec/datasets/odps_dataset.py +++ b/tzrec/datasets/odps_dataset.py @@ -547,3 +547,4 @@ def close(self) -> None: raise RuntimeError( f"Fail to commit write session: {self._sess_req.session_id}" ) + super().close() diff --git a/tzrec/datasets/parquet_dataset.py b/tzrec/datasets/parquet_dataset.py index 3fde70b..c8f9765 100644 --- a/tzrec/datasets/parquet_dataset.py +++ b/tzrec/datasets/parquet_dataset.py @@ -143,3 +143,4 @@ def close(self) -> None: """Close and commit data.""" if self._writer is not None: self._writer.close() + super().close() diff --git a/tzrec/models/tdm.py b/tzrec/models/tdm.py index 5d25988..8707fd8 100644 --- a/tzrec/models/tdm.py +++ b/tzrec/models/tdm.py @@ -43,6 +43,7 @@ def __init__( ) non_seq_fea_dim = 0 + self.seq_group_name = "" self.non_seq_group_name = [] query_emb_dim = 0 for feature_group in model_config.feature_groups: diff --git a/tzrec/tools/tdm/cluster_tree.py b/tzrec/tools/tdm/cluster_tree.py index c689d0c..b0d7145 100644 --- a/tzrec/tools/tdm/cluster_tree.py +++ b/tzrec/tools/tdm/cluster_tree.py @@ -71,6 +71,12 @@ default=2, help="The branching factor of the nodes in the tree.", ) + parser.add_argument( + "--odps_data_quota_name", + type=str, + default="pay-as-you-go", + help="maxcompute storage api/tunnel data quota name.", + ) args, extra_args = parser.parse_known_args() cluster = TreeCluster( @@ -82,6 +88,7 @@ embedding_field=args.embedding_field, parallel=args.parallel, n_cluster=args.n_cluster, + odps_data_quota_name=args.odps_data_quota_name, ) root = cluster.train() logger.info("Tree cluster done. Start save nodes and edges table.") @@ -89,6 +96,7 @@ output_file=args.node_edge_output_file, root=root, child_num=args.n_cluster, + odps_data_quota_name=args.odps_data_quota_name, ) tree_search.save() tree_search.save_predict_edge() diff --git a/tzrec/tools/tdm/gen_tree/tree_cluster.py b/tzrec/tools/tdm/gen_tree/tree_cluster.py index cccabc6..83ceef5 100644 --- a/tzrec/tools/tdm/gen_tree/tree_cluster.py +++ b/tzrec/tools/tdm/gen_tree/tree_cluster.py @@ -14,7 +14,7 @@ import os import time from multiprocessing.connection import Connection -from typing import List, Optional +from typing import Any, List, Optional import numpy as np import numpy.typing as npt @@ -49,6 +49,7 @@ def __init__( embedding_field: str = "item_emb", parallel: int = 16, n_cluster: int = 2, + **kwargs: Any, ) -> None: self.item_input_path = item_input_path self.mini_batch = 1024 @@ -73,15 +74,19 @@ def __init__( self.embedding_field = embedding_field + self.dataset_kwargs = {} + if "odps_data_quota_name" in kwargs: + self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"] + def _read(self) -> None: t1 = time.time() ids = list() data = list() attrs = list() raw_attrs = list() - reader = create_reader(self.item_input_path, 256) + reader = create_reader(self.item_input_path, 256, **self.dataset_kwargs) for data_dict in reader.to_batches(): - ids += data_dict[self.item_id_field].to_pylist() + ids += data_dict[self.item_id_field].cast(pa.int64()).to_pylist() data += data_dict[self.embedding_field].to_pylist() tmp_attr = [] if self.attr_fields is not None: diff --git a/tzrec/tools/tdm/gen_tree/tree_generator.py b/tzrec/tools/tdm/gen_tree/tree_generator.py index 83eef19..c61be7c 100644 --- a/tzrec/tools/tdm/gen_tree/tree_generator.py +++ b/tzrec/tools/tdm/gen_tree/tree_generator.py @@ -9,7 +9,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Optional, Union import numpy as np import pyarrow as pa @@ -39,6 +39,7 @@ def __init__( raw_attr_fields: Optional[str] = None, tree_output_dir: Optional[str] = None, n_cluster: int = 2, + **kwargs: Any, ) -> None: self.item_input_path = item_input_path self.item_id_field = item_id_field @@ -52,6 +53,10 @@ def __init__( self.tree_output_dir = tree_output_dir self.n_cluster = n_cluster + self.dataset_kwargs = {} + if "odps_data_quota_name" in kwargs: + self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"] + def generate(self, save_tree: bool = False) -> TDMTreeClass: """Generate tree.""" item_fea = self._read() @@ -60,9 +65,11 @@ def generate(self, save_tree: bool = False) -> TDMTreeClass: def _read(self) -> Dict[str, List[Union[int, float, str]]]: item_fea = {"ids": [], "cates": [], "attrs": [], "raw_attrs": []} - reader = create_reader(self.item_input_path, 4096) + reader = create_reader(self.item_input_path, 4096, **self.dataset_kwargs) for data_dict in reader.to_batches(): - item_fea["ids"] += data_dict[self.item_id_field].to_pylist() + item_fea["ids"] += ( + data_dict[self.item_id_field].cast(pa.int64()).to_pylist() + ) item_fea["cates"] += ( data_dict[self.cate_id_field] .cast(pa.string()) diff --git a/tzrec/tools/tdm/gen_tree/tree_search_util.py b/tzrec/tools/tdm/gen_tree/tree_search_util.py index 7c6e87b..a6db55b 100644 --- a/tzrec/tools/tdm/gen_tree/tree_search_util.py +++ b/tzrec/tools/tdm/gen_tree/tree_search_util.py @@ -12,7 +12,7 @@ import os import pickle from collections import OrderedDict -from typing import Callable, Iterator, List, Optional, Tuple +from typing import Any, Callable, Iterator, List, Optional, Tuple import pyarrow as pa from anytree.importer.dictimporter import DictImporter @@ -62,6 +62,7 @@ def __init__( tree_path: Optional[str] = None, root: Optional[TDMTreeClass] = None, child_num: int = 2, + **kwargs: Any, ) -> None: self.child_num = child_num if root is not None: @@ -81,6 +82,10 @@ def __init__( self.output_file = output_file + self.dataset_kwargs = {} + if "odps_data_quota_name" in kwargs: + self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"] + self._get_nodes() def _load(self, path: str) -> None: @@ -99,6 +104,9 @@ def _get_nodes(self) -> None: self.max_level = level - 1 self.level_code.append([]) self.level_code[self.max_level].append(node) + logger.info( + f"Tree Level: {self.max_level + 1}, Tree Cluster: {self.child_num}." + ) tree_walker = Walker() logger.info("Begin Travel Tree.") @@ -111,7 +119,9 @@ def _get_nodes(self) -> None: def save(self) -> None: """Save tree info.""" if self.output_file.startswith("odps://"): - node_writer = create_writer(self.output_file + "node_table") + str_list = self.output_file.split("/") + str_list[4] = str_list[4] + "_node_table" + node_writer = create_writer("/".join(str_list), **self.dataset_kwargs) ids = [] weight = [] features = [] @@ -130,8 +140,11 @@ def save(self) -> None: node_table_dict["weight"] = pa.array(weight) node_table_dict["features"] = pa.array(features) node_writer.write(node_table_dict) + node_writer.close() - edge_writer = create_writer(self.output_file + "edge_table") + str_list = self.output_file.split("/") + str_list[4] = str_list[4] + "_edge_table" + edge_writer = create_writer("/".join(str_list), **self.dataset_kwargs) src_ids = [] dst_ids = [] weight = [] @@ -146,6 +159,7 @@ def save(self) -> None: edge_table_dict["dst_id"] = pa.array(dst_ids) edge_table_dict["weight"] = pa.array(weight) edge_writer.write(edge_table_dict) + edge_writer.close() else: if not os.path.exists(self.output_file): @@ -172,7 +186,9 @@ def save(self) -> None: def save_predict_edge(self) -> None: """Save edge info for prediction.""" if self.output_file.startswith("odps://"): - writer = create_writer(self.output_file + "predict_edge_table") + str_list = self.output_file.split("/") + str_list[4] = str_list[4] + "_predict_edge_table" + writer = create_writer("/".join(str_list), **self.dataset_kwargs) src_ids = [] dst_ids = [] weight = [] @@ -187,6 +203,7 @@ def save_predict_edge(self) -> None: edge_table_dict["dst_id"] = pa.array(dst_ids) edge_table_dict["weight"] = pa.array(weight) writer.write(edge_table_dict) + writer.close() else: with open( os.path.join(self.output_file, "predict_edge_table.txt"), "w" diff --git a/tzrec/tools/tdm/init_tree.py b/tzrec/tools/tdm/init_tree.py index d6d3954..902fade 100644 --- a/tzrec/tools/tdm/init_tree.py +++ b/tzrec/tools/tdm/init_tree.py @@ -65,6 +65,12 @@ default=2, help="The branching factor of the nodes in the tree.", ) + parser.add_argument( + "--odps_data_quota_name", + type=str, + default="pay-as-you-go", + help="maxcompute storage api/tunnel data quota name.", + ) args, extra_args = parser.parse_known_args() generator = TreeGenerator( @@ -75,6 +81,7 @@ raw_attr_fields=args.raw_attr_fields, tree_output_dir=args.tree_output_dir, n_cluster=args.n_cluster, + odps_data_quota_name=args.odps_data_quota_name, ) root = generator.generate() logger.info("Tree init done. Start save nodes and edges table.") @@ -82,6 +89,7 @@ output_file=args.node_edge_output_file, root=root, child_num=args.n_cluster, + odps_data_quota_name=args.odps_data_quota_name, ) tree_search.save() tree_search.save_predict_edge() diff --git a/tzrec/version.py b/tzrec/version.py index 76aaa2e..7b5720f 100644 --- a/tzrec/version.py +++ b/tzrec/version.py @@ -9,4 +9,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.5.5" +__version__ = "0.5.6"