Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] add odps quota name and partition for tdm tree build #6

Merged
merged 8 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tzrec/datasets/csv_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,4 @@ def close(self) -> None:
"""Close and commit data."""
if self._writer is not None:
self._writer.close()
super().close()
8 changes: 7 additions & 1 deletion tzrec/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from tzrec.features.feature import BaseFeature
from tzrec.protos import data_pb2
from tzrec.utils.load_class import get_register_class_meta
from tzrec.utils.logging_util import logger

_DATASET_CLASS_MAP = {}
_READER_CLASS_MAP = {}
Expand Down Expand Up @@ -429,7 +430,12 @@ def write(self, output_dict: OrderedDict[str, pa.Array]) -> None:

def close(self) -> None:
"""Close and commit data."""
pass
self._lazy_inited = False

def __del__(self) -> None:
if self._lazy_inited:
# pyre-ignore [16]
logger.warning(f"You should close {self.__class__.__name__} explicitly.")


def create_reader(
Expand Down
1 change: 1 addition & 0 deletions tzrec/datasets/odps_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,3 +547,4 @@ def close(self) -> None:
raise RuntimeError(
f"Fail to commit write session: {self._sess_req.session_id}"
)
super().close()
1 change: 1 addition & 0 deletions tzrec/datasets/parquet_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,4 @@ def close(self) -> None:
"""Close and commit data."""
if self._writer is not None:
self._writer.close()
super().close()
1 change: 1 addition & 0 deletions tzrec/models/tdm.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ def __init__(
)

non_seq_fea_dim = 0
self.seq_group_name = ""
self.non_seq_group_name = []
query_emb_dim = 0
for feature_group in model_config.feature_groups:
Expand Down
8 changes: 8 additions & 0 deletions tzrec/tools/tdm/cluster_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@
default=2,
help="The branching factor of the nodes in the tree.",
)
parser.add_argument(
"--odps_data_quota_name",
type=str,
default="pay-as-you-go",
help="maxcompute storage api/tunnel data quota name.",
)
args, extra_args = parser.parse_known_args()

cluster = TreeCluster(
Expand All @@ -82,13 +88,15 @@
embedding_field=args.embedding_field,
parallel=args.parallel,
n_cluster=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
root = cluster.train()
logger.info("Tree cluster done. Start save nodes and edges table.")
tree_search = TreeSearch(
output_file=args.node_edge_output_file,
root=root,
child_num=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
tree_search.save()
tree_search.save_predict_edge()
Expand Down
11 changes: 8 additions & 3 deletions tzrec/tools/tdm/gen_tree/tree_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
import time
from multiprocessing.connection import Connection
from typing import List, Optional
from typing import Any, List, Optional

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -49,6 +49,7 @@ def __init__(
embedding_field: str = "item_emb",
parallel: int = 16,
n_cluster: int = 2,
**kwargs: Any,
) -> None:
self.item_input_path = item_input_path
self.mini_batch = 1024
Expand All @@ -73,15 +74,19 @@ def __init__(

self.embedding_field = embedding_field

self.dataset_kwargs = {}
if "odps_data_quota_name" in kwargs:
self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

def _read(self) -> None:
t1 = time.time()
ids = list()
data = list()
attrs = list()
raw_attrs = list()
reader = create_reader(self.item_input_path, 256)
reader = create_reader(self.item_input_path, 256, **self.dataset_kwargs)
for data_dict in reader.to_batches():
ids += data_dict[self.item_id_field].to_pylist()
ids += data_dict[self.item_id_field].cast(pa.int64()).to_pylist()
data += data_dict[self.embedding_field].to_pylist()
tmp_attr = []
if self.attr_fields is not None:
Expand Down
13 changes: 10 additions & 3 deletions tzrec/tools/tdm/gen_tree/tree_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -39,6 +39,7 @@ def __init__(
raw_attr_fields: Optional[str] = None,
tree_output_dir: Optional[str] = None,
n_cluster: int = 2,
**kwargs: Any,
) -> None:
self.item_input_path = item_input_path
self.item_id_field = item_id_field
Expand All @@ -52,6 +53,10 @@ def __init__(
self.tree_output_dir = tree_output_dir
self.n_cluster = n_cluster

self.dataset_kwargs = {}
if "odps_data_quota_name" in kwargs:
self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

def generate(self, save_tree: bool = False) -> TDMTreeClass:
"""Generate tree."""
item_fea = self._read()
Expand All @@ -60,9 +65,11 @@ def generate(self, save_tree: bool = False) -> TDMTreeClass:

def _read(self) -> Dict[str, List[Union[int, float, str]]]:
item_fea = {"ids": [], "cates": [], "attrs": [], "raw_attrs": []}
reader = create_reader(self.item_input_path, 4096)
reader = create_reader(self.item_input_path, 4096, **self.dataset_kwargs)
for data_dict in reader.to_batches():
item_fea["ids"] += data_dict[self.item_id_field].to_pylist()
item_fea["ids"] += (
data_dict[self.item_id_field].cast(pa.int64()).to_pylist()
)
item_fea["cates"] += (
data_dict[self.cate_id_field]
.cast(pa.string())
Expand Down
25 changes: 21 additions & 4 deletions tzrec/tools/tdm/gen_tree/tree_search_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import os
import pickle
from collections import OrderedDict
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Any, Callable, Iterator, List, Optional, Tuple

import pyarrow as pa
from anytree.importer.dictimporter import DictImporter
Expand Down Expand Up @@ -62,6 +62,7 @@ def __init__(
tree_path: Optional[str] = None,
root: Optional[TDMTreeClass] = None,
child_num: int = 2,
**kwargs: Any,
) -> None:
self.child_num = child_num
if root is not None:
Expand All @@ -81,6 +82,10 @@ def __init__(

self.output_file = output_file

self.dataset_kwargs = {}
if "odps_data_quota_name" in kwargs:
self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

self._get_nodes()

def _load(self, path: str) -> None:
Expand All @@ -99,6 +104,9 @@ def _get_nodes(self) -> None:
self.max_level = level - 1
self.level_code.append([])
self.level_code[self.max_level].append(node)
logger.info(
f"Tree Level: {self.max_level + 1}, Tree Cluster: {self.child_num}."
)

tree_walker = Walker()
logger.info("Begin Travel Tree.")
Expand All @@ -111,7 +119,9 @@ def _get_nodes(self) -> None:
def save(self) -> None:
"""Save tree info."""
if self.output_file.startswith("odps://"):
node_writer = create_writer(self.output_file + "node_table")
str_list = self.output_file.split("/")
str_list[4] = str_list[4] + "_node_table"
node_writer = create_writer("/".join(str_list), **self.dataset_kwargs)
ids = []
weight = []
features = []
Expand All @@ -130,8 +140,11 @@ def save(self) -> None:
node_table_dict["weight"] = pa.array(weight)
node_table_dict["features"] = pa.array(features)
node_writer.write(node_table_dict)
node_writer.close()

edge_writer = create_writer(self.output_file + "edge_table")
str_list = self.output_file.split("/")
str_list[4] = str_list[4] + "_edge_table"
edge_writer = create_writer("/".join(str_list), **self.dataset_kwargs)
src_ids = []
dst_ids = []
weight = []
Expand All @@ -146,6 +159,7 @@ def save(self) -> None:
edge_table_dict["dst_id"] = pa.array(dst_ids)
edge_table_dict["weight"] = pa.array(weight)
edge_writer.write(edge_table_dict)
edge_writer.close()

else:
if not os.path.exists(self.output_file):
Expand All @@ -172,7 +186,9 @@ def save(self) -> None:
def save_predict_edge(self) -> None:
"""Save edge info for prediction."""
if self.output_file.startswith("odps://"):
writer = create_writer(self.output_file + "predict_edge_table")
str_list = self.output_file.split("/")
str_list[4] = str_list[4] + "_predict_edge_table"
writer = create_writer("/".join(str_list), **self.dataset_kwargs)
src_ids = []
dst_ids = []
weight = []
Expand All @@ -187,6 +203,7 @@ def save_predict_edge(self) -> None:
edge_table_dict["dst_id"] = pa.array(dst_ids)
edge_table_dict["weight"] = pa.array(weight)
writer.write(edge_table_dict)
writer.close()
else:
with open(
os.path.join(self.output_file, "predict_edge_table.txt"), "w"
Expand Down
8 changes: 8 additions & 0 deletions tzrec/tools/tdm/init_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@
default=2,
help="The branching factor of the nodes in the tree.",
)
parser.add_argument(
"--odps_data_quota_name",
type=str,
default="pay-as-you-go",
help="maxcompute storage api/tunnel data quota name.",
)
args, extra_args = parser.parse_known_args()

generator = TreeGenerator(
Expand All @@ -75,13 +81,15 @@
raw_attr_fields=args.raw_attr_fields,
tree_output_dir=args.tree_output_dir,
n_cluster=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
root = generator.generate()
logger.info("Tree init done. Start save nodes and edges table.")
tree_search = TreeSearch(
output_file=args.node_edge_output_file,
root=root,
child_num=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
tree_search.save()
tree_search.save_predict_edge()
Expand Down
2 changes: 1 addition & 1 deletion tzrec/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,4 @@
# See the License for the specific language governing permissions and
# limitations under the License.

__version__ = "0.5.5"
__version__ = "0.5.6"
Loading