Skip to content

Commit

Permalink
add quota name and cast id to int64 for tdm tree
Browse files: browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji committed Oct 9, 2024
1 parent 2faf03d commit a80dfb3
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 10 deletions.
8 changes: 8 additions & 0 deletions tzrec/tools/tdm/cluster_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,12 @@
default=2,
help="The branching factor of the nodes in the tree.",
)
parser.add_argument(
"--odps_data_quota_name",
type=str,
default="pay-as-you-go",
help="maxcompute storage api/tunnel data quota name.",
)
args, extra_args = parser.parse_known_args()

cluster = TreeCluster(
Expand All @@ -82,13 +88,15 @@
embedding_field=args.embedding_field,
parallel=args.parallel,
n_cluster=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
root = cluster.train()
logger.info("Tree cluster done. Start save nodes and edges table.")
tree_search = TreeSearch(
output_file=args.node_edge_output_file,
root=root,
child_num=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
tree_search.save()
tree_search.save_predict_edge()
Expand Down
11 changes: 8 additions & 3 deletions tzrec/tools/tdm/gen_tree/tree_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import os
import time
from multiprocessing.connection import Connection
from typing import List, Optional
from typing import Any, List, Optional

import numpy as np
import numpy.typing as npt
Expand Down Expand Up @@ -49,6 +49,7 @@ def __init__(
embedding_field: str = "item_emb",
parallel: int = 16,
n_cluster: int = 2,
**kwargs: Any,
) -> None:
self.item_input_path = item_input_path
self.mini_batch = 1024
Expand All @@ -73,15 +74,19 @@ def __init__(

self.embedding_field = embedding_field

self.dataset_kwargs = {}
if "odps_data_quota_name" in kwargs:
self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

def _read(self) -> None:
t1 = time.time()
ids = list()
data = list()
attrs = list()
raw_attrs = list()
reader = create_reader(self.item_input_path, 256)
reader = create_reader(self.item_input_path, 256, **self.dataset_kwargs)
for data_dict in reader.to_batches():
ids += data_dict[self.item_id_field].to_pylist()
ids += data_dict[self.item_id_field].cast(pa.int64()).to_pylist()
data += data_dict[self.embedding_field].to_pylist()
tmp_attr = []
if self.attr_fields is not None:
Expand Down
13 changes: 10 additions & 3 deletions tzrec/tools/tdm/gen_tree/tree_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union

import numpy as np
import pyarrow as pa
Expand Down Expand Up @@ -39,6 +39,7 @@ def __init__(
raw_attr_fields: Optional[str] = None,
tree_output_dir: Optional[str] = None,
n_cluster: int = 2,
**kwargs: Any,
) -> None:
self.item_input_path = item_input_path
self.item_id_field = item_id_field
Expand All @@ -52,6 +53,10 @@ def __init__(
self.tree_output_dir = tree_output_dir
self.n_cluster = n_cluster

self.dataset_kwargs = {}
if "odps_data_quota_name" in kwargs:
self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

def generate(self, save_tree: bool = False) -> TDMTreeClass:
"""Generate tree."""
item_fea = self._read()
Expand All @@ -60,9 +65,11 @@ def generate(self, save_tree: bool = False) -> TDMTreeClass:

def _read(self) -> Dict[str, List[Union[int, float, str]]]:
item_fea = {"ids": [], "cates": [], "attrs": [], "raw_attrs": []}
reader = create_reader(self.item_input_path, 4096)
reader = create_reader(self.item_input_path, 4096, **self.dataset_kwargs)
for data_dict in reader.to_batches():
item_fea["ids"] += data_dict[self.item_id_field].to_pylist()
item_fea["ids"] += (
data_dict[self.item_id_field].cast(pa.int64()).to_pylist()
)
item_fea["cates"] += (
data_dict[self.cate_id_field]
.cast(pa.string())
Expand Down
19 changes: 15 additions & 4 deletions tzrec/tools/tdm/gen_tree/tree_search_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import os
import pickle
from collections import OrderedDict
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Any, Callable, Iterator, List, Optional, Tuple

import pyarrow as pa
from anytree.importer.dictimporter import DictImporter
Expand Down Expand Up @@ -62,6 +62,7 @@ def __init__(
tree_path: Optional[str] = None,
root: Optional[TDMTreeClass] = None,
child_num: int = 2,
**kwargs: Any,
) -> None:
self.child_num = child_num
if root is not None:
Expand All @@ -81,6 +82,10 @@ def __init__(

self.output_file = output_file

self.dataset_kwargs = {}
if "odps_data_quota_name" in kwargs:
self.dataset_kwargs["quota_name"] = kwargs["odps_data_quota_name"]

self._get_nodes()

def _load(self, path: str) -> None:
Expand Down Expand Up @@ -111,7 +116,9 @@ def _get_nodes(self) -> None:
def save(self) -> None:
"""Save tree info."""
if self.output_file.startswith("odps://"):
node_writer = create_writer(self.output_file + "node_table")
node_writer = create_writer(
self.output_file + "node_table", **self.dataset_kwargs
)
ids = []
weight = []
features = []
Expand All @@ -131,7 +138,9 @@ def save(self) -> None:
node_table_dict["features"] = pa.array(features)
node_writer.write(node_table_dict)

edge_writer = create_writer(self.output_file + "edge_table")
edge_writer = create_writer(
self.output_file + "edge_table", **self.dataset_kwargs
)
src_ids = []
dst_ids = []
weight = []
Expand Down Expand Up @@ -170,7 +179,9 @@ def save(self) -> None:
def save_predict_edge(self) -> None:
"""Save edge info for prediction."""
if self.output_file.startswith("odps://"):
writer = create_writer(self.output_file + "predict_edge_table")
writer = create_writer(
self.output_file + "predict_edge_table", **self.dataset_kwargs
)
src_ids = []
dst_ids = []
weight = []
Expand Down
8 changes: 8 additions & 0 deletions tzrec/tools/tdm/init_tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,12 @@
default=2,
help="The branching factor of the nodes in the tree.",
)
parser.add_argument(
"--odps_data_quota_name",
type=str,
default="pay-as-you-go",
help="maxcompute storage api/tunnel data quota name.",
)
args, extra_args = parser.parse_known_args()

generator = TreeGenerator(
Expand All @@ -75,13 +81,15 @@
raw_attr_fields=args.raw_attr_fields,
tree_output_dir=args.tree_output_dir,
n_cluster=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
root = generator.generate()
logger.info("Tree init done. Start save nodes and edges table.")
tree_search = TreeSearch(
output_file=args.node_edge_output_file,
root=root,
child_num=args.n_cluster,
odps_data_quota_name=args.odps_data_quota_name,
)
tree_search.save()
tree_search.save_predict_edge()
Expand Down

0 comments on commit a80dfb3

Please sign in to comment.