Skip to content

Commit

Permalink
[feat] add odps quota name and partition for tdm tree build (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji authored Oct 10, 2024
1 parent e25bc04 commit 435ed0a
Show file tree
Hide file tree
Showing 11 changed files with 67 additions and 12 deletions.
1 change: 1 addition & 0 deletions tzrec/datasets/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,4 @@ def close(self) -> None:
"""Close and commit data."""
if self._writer is not None:
self._writer.close()
super().close()
8 changes: 7 additions & 1 deletion tzrec/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tzrec.features.feature import BaseFeature
from tzrec.protos import data_pb2
from tzrec.utils.load_class import get_register_class_meta
from tzrec.utils.logging_util import logger

_DATASET_CLASS_MAP = {}
_READER_CLASS_MAP = {}
Expand Down Expand Up @@ -429,7 +430,12 @@ def write(self, output_dict: OrderedDict[str, pa.Array]) -> None:

def close(self) -> None:
    """Close and commit data.

    The base implementation only resets ``_lazy_inited`` to ``False`` so
    that ``__del__`` does not warn about a missing explicit close.
    """
    pass
    self._lazy_inited = False

def __del__(self) -> None:
    """Warn when the instance is finalized without an explicit ``close()``.

    ``close()`` resets ``_lazy_inited`` to ``False``; a truthy flag at
    finalization time therefore means ``close()`` was never called.
    """
    # __del__ may run on a partially constructed instance (e.g. when
    # __init__ raised before the flag was set), so read the flag
    # defensively instead of raising AttributeError during finalization.
    if getattr(self, "_lazy_inited", False):
        # pyre-ignore [16]
        logger.warning(f"You should close {self.__class__.__name__} explicitly.")


def create_reader(
Expand Down
1 change: 1 addition & 0 deletions tzrec/datasets/odps_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,3 +547,4 @@ def close(self) -> None:
raise RuntimeError(
f"Fail to commit write session: {self._sess_req.session_id}"
)
super().close()
1 change: 1 addition & 0 deletions tzrec/datasets/parquet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,4 @@ def close(self) -> None:
"""Close and commit data."""
if self._writer is not None:
self._writer.close()
super().close()
1 change: 1 addition & 0 deletions tzrec/models/tdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
)

non_seq_fea_dim = 0
self.seq_group_name = ""
self.non_seq_group_name = []
query_emb_dim = 0
for feature_group in model_config.feature_groups:
Expand Down
8 changes: 8 additions & 0 deletions tzrec/tools/tdm/cluster_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@
default=2,
help="The branching factor of the nodes in the tree.",
)
parser.add_argument(
"--odps_data_quota_name",
type=str,
default="pay-as-you-go",
help="maxcompute storage api/tunnel data quota name.",
)
args, extra_args = parser.parse_known_args()

cluster = TreeCluster(
Expand All @@ -82,13 +88,15 @@
embedding_field=args.embedding_field,
parallel=args.parallel,
n_cluster=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
root = cluster.train()
logger.info("Tree cluster done. Start save nodes and edges table.")
tree_search = TreeSearch(
output_file=args.node_edge_output_file,
root=root,
child_num=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
tree_search.save()
tree_search.save_predict_edge()
Expand Down
11 changes: 8 additions & 3 deletions tzrec/tools/tdm/gen_tree/tree_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
import time
from multiprocessing.connection import Connection
from typing import List, Optional
from typing import Any, List, Optional

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -49,6 +49,7 @@ def __init__(
embedding_field: str = "item_emb",
parallel: int = 16,
n_cluster: int = 2,
**kwargs: Any,
) -> None:
self.item_input_path = item_input_path
self.mini_batch = 1024
Expand All @@ -73,15 +74,19 @@ def __init__(

self.embedding_field = embedding_field

self.dataset_kwargs = {}
if "odps_data_quota_name" in kwargs:
self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

def _read(self) -> None:
t1 = time.time()
ids = list()
data = list()
attrs = list()
raw_attrs = list()
reader = create_reader(self.item_input_path, 256)
reader = create_reader(self.item_input_path, 256, **self.dataset_kwargs)
for data_dict in reader.to_batches():
ids += data_dict[self.item_id_field].to_pylist()
ids += data_dict[self.item_id_field].cast(pa.int64()).to_pylist()
data += data_dict[self.embedding_field].to_pylist()
tmp_attr = []
if self.attr_fields is not None:
Expand Down
13 changes: 10 additions & 3 deletions tzrec/tools/tdm/gen_tree/tree_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -39,6 +39,7 @@ def __init__(
raw_attr_fields: Optional[str] = None,
tree_output_dir: Optional[str] = None,
n_cluster: int = 2,
**kwargs: Any,
) -> None:
self.item_input_path = item_input_path
self.item_id_field = item_id_field
Expand All @@ -52,6 +53,10 @@ def __init__(
self.tree_output_dir = tree_output_dir
self.n_cluster = n_cluster

self.dataset_kwargs = {}
if "odps_data_quota_name" in kwargs:
self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

def generate(self, save_tree: bool = False) -> TDMTreeClass:
"""Generate tree."""
item_fea = self._read()
Expand All @@ -60,9 +65,11 @@ def generate(self, save_tree: bool = False) -> TDMTreeClass:

def _read(self) -> Dict[str, List[Union[int, float, str]]]:
item_fea = {"ids": [], "cates": [], "attrs": [], "raw_attrs": []}
reader = create_reader(self.item_input_path, 4096)
reader = create_reader(self.item_input_path, 4096, **self.dataset_kwargs)
for data_dict in reader.to_batches():
item_fea["ids"] += data_dict[self.item_id_field].to_pylist()
item_fea["ids"] += (
data_dict[self.item_id_field].cast(pa.int64()).to_pylist()
)
item_fea["cates"] += (
data_dict[self.cate_id_field]
.cast(pa.string())
Expand Down
25 changes: 21 additions & 4 deletions tzrec/tools/tdm/gen_tree/tree_search_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import os
import pickle
from collections import OrderedDict
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Any, Callable, Iterator, List, Optional, Tuple

import pyarrow as pa
from anytree.importer.dictimporter import DictImporter
Expand Down Expand Up @@ -62,6 +62,7 @@ def __init__(
tree_path: Optional[str] = None,
root: Optional[TDMTreeClass] = None,
child_num: int = 2,
**kwargs: Any,
) -> None:
self.child_num = child_num
if root is not None:
Expand All @@ -81,6 +82,10 @@ def __init__(

self.output_file = output_file

self.dataset_kwargs = {}
if "odps_data_quota_name" in kwargs:
self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

self._get_nodes()

def _load(self, path: str) -> None:
Expand All @@ -99,6 +104,9 @@ def _get_nodes(self) -> None:
self.max_level = level - 1
self.level_code.append([])
self.level_code[self.max_level].append(node)
logger.info(
f"Tree Level: {self.max_level + 1}, Tree Cluster: {self.child_num}."
)

tree_walker = Walker()
logger.info("Begin Travel Tree.")
Expand All @@ -111,7 +119,9 @@ def _get_nodes(self) -> None:
def save(self) -> None:
"""Save tree info."""
if self.output_file.startswith("odps://"):
node_writer = create_writer(self.output_file + "node_table")
str_list = self.output_file.split("/")
str_list[4] = str_list[4] + "_node_table"
node_writer = create_writer("/".join(str_list), **self.dataset_kwargs)
ids = []
weight = []
features = []
Expand All @@ -130,8 +140,11 @@ def save(self) -> None:
node_table_dict["weight"] = pa.array(weight)
node_table_dict["features"] = pa.array(features)
node_writer.write(node_table_dict)
node_writer.close()

edge_writer = create_writer(self.output_file + "edge_table")
str_list = self.output_file.split("/")
str_list[4] = str_list[4] + "_edge_table"
edge_writer = create_writer("/".join(str_list), **self.dataset_kwargs)
src_ids = []
dst_ids = []
weight = []
Expand All @@ -146,6 +159,7 @@ def save(self) -> None:
edge_table_dict["dst_id"] = pa.array(dst_ids)
edge_table_dict["weight"] = pa.array(weight)
edge_writer.write(edge_table_dict)
edge_writer.close()

else:
if not os.path.exists(self.output_file):
Expand All @@ -172,7 +186,9 @@ def save(self) -> None:
def save_predict_edge(self) -> None:
"""Save edge info for prediction."""
if self.output_file.startswith("odps://"):
writer = create_writer(self.output_file + "predict_edge_table")
str_list = self.output_file.split("/")
str_list[4] = str_list[4] + "_predict_edge_table"
writer = create_writer("/".join(str_list), **self.dataset_kwargs)
src_ids = []
dst_ids = []
weight = []
Expand All @@ -187,6 +203,7 @@ def save_predict_edge(self) -> None:
edge_table_dict["dst_id"] = pa.array(dst_ids)
edge_table_dict["weight"] = pa.array(weight)
writer.write(edge_table_dict)
writer.close()
else:
with open(
os.path.join(self.output_file, "predict_edge_table.txt"), "w"
Expand Down
8 changes: 8 additions & 0 deletions tzrec/tools/tdm/init_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@
default=2,
help="The branching factor of the nodes in the tree.",
)
parser.add_argument(
"--odps_data_quota_name",
type=str,
default="pay-as-you-go",
help="maxcompute storage api/tunnel data quota name.",
)
args, extra_args = parser.parse_known_args()

generator = TreeGenerator(
Expand All @@ -75,13 +81,15 @@
raw_attr_fields=args.raw_attr_fields,
tree_output_dir=args.tree_output_dir,
n_cluster=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
root = generator.generate()
logger.info("Tree init done. Start save nodes and edges table.")
tree_search = TreeSearch(
output_file=args.node_edge_output_file,
root=root,
child_num=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
tree_search.save()
tree_search.save_predict_edge()
Expand Down
2 changes: 1 addition & 1 deletion tzrec/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.5.5"
__version__ = "0.5.6"

0 comments on commit 435ed0a

Please sign in to comment.