Skip to content

Commit

Permalink
add bucketize only mode & refactor fg_encoded to fg_mode
Browse files Browse the repository at this point in the history
  • Loading branch information
tiankongdeguiji committed Dec 13, 2024
1 parent 760221f commit 4d25a8f
Show file tree
Hide file tree
Showing 61 changed files with 293 additions and 181 deletions.
1 change: 0 additions & 1 deletion tzrec/benchmark/configs/criteo/deepfm.config
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ eval_config {}
data_config {
batch_size: 8192
dataset_type: OdpsDataset
fg_encoded: true
label_fields: "label"
num_workers: 8
odps_data_quota_name: ""
Expand Down
1 change: 0 additions & 1 deletion tzrec/benchmark/configs/taobao/dbmtl.config
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ eval_config {
data_config {
batch_size: 8192
dataset_type: OdpsDataset
fg_encoded: true
label_fields: "clk"
label_fields: "buy"
num_workers: 8
Expand Down
1 change: 0 additions & 1 deletion tzrec/benchmark/configs/taobao/dbmtl_has_sequence.config
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ eval_config {
data_config {
batch_size: 8192
dataset_type: OdpsDataset
fg_encoded: true
label_fields: "clk"
label_fields: "buy"
num_workers: 8
Expand Down
1 change: 0 additions & 1 deletion tzrec/benchmark/configs/taobao/dbmtl_jrc.config
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ eval_config {
data_config {
batch_size: 8192
dataset_type: OdpsDataset
fg_encoded: true
label_fields: "clk"
label_fields: "buy"
num_workers: 8
Expand Down
1 change: 0 additions & 1 deletion tzrec/benchmark/configs/taobao/mmoe.config
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ eval_config {
data_config {
batch_size: 8192
dataset_type: OdpsDataset
fg_encoded: true
label_fields: "clk"
label_fields: "buy"
num_workers: 8
Expand Down
1 change: 0 additions & 1 deletion tzrec/benchmark/configs/taobao/mmoe_has_sequence.config
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ eval_config {
data_config {
batch_size: 8192
dataset_type: OdpsDataset
fg_encoded: true
label_fields: "clk"
label_fields: "buy"
num_workers: 8
Expand Down
1 change: 0 additions & 1 deletion tzrec/benchmark/configs/taobao/ple.config
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ eval_config {
data_config {
batch_size: 8192
dataset_type: OdpsDataset
fg_encoded: true
label_fields: "clk"
label_fields: "buy"
num_workers: 8
Expand Down
1 change: 0 additions & 1 deletion tzrec/benchmark/configs/taobao/ple_has_sequence.config
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ eval_config {
data_config {
batch_size: 8192
dataset_type: OdpsDataset
fg_encoded: true
label_fields: "clk"
label_fields: "buy"
num_workers: 8
Expand Down
2 changes: 1 addition & 1 deletion tzrec/benchmark/configs/taobao_ccp/dbmtl.config
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ eval_config {
data_config {
batch_size: 8192
dataset_type: OdpsDataset
fg_encoded: false
fg_mode: FG_DAG
label_fields: "click_label"
label_fields: "conversion_label"
num_workers: 8
Expand Down
2 changes: 1 addition & 1 deletion tzrec/benchmark/configs/taobao_ccp/mmoe.config
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ eval_config {
data_config {
batch_size: 8192
dataset_type: OdpsDataset
fg_encoded: false
fg_mode: FG_DAG
label_fields: "click_label"
label_fields: "conversion_label"
num_workers: 8
Expand Down
2 changes: 1 addition & 1 deletion tzrec/benchmark/configs/taobao_ccp/ple.config
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ eval_config {
data_config {
batch_size: 8192
dataset_type: OdpsDataset
fg_encoded: false
fg_mode: FG_DAG
label_fields: "click_label"
label_fields: "conversion_label"
num_workers: 8
Expand Down
6 changes: 3 additions & 3 deletions tzrec/datasets/csv_dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_csv_dataset(self, with_header, num_rows):
data_config = data_pb2.DataConfig(
batch_size=4,
dataset_type=data_pb2.DatasetType.CsvDataset,
fg_encoded=True,
fg_mode=data_pb2.FgMode.FG_NONE,
label_fields=["label"],
with_header=with_header,
)
Expand Down Expand Up @@ -134,7 +134,7 @@ def test_csv_dataset_with_all_nulls(self):
)
),
]
features = create_features(feature_cfgs, fg_mode=FgMode.DAG)
features = create_features(feature_cfgs, fg_mode=FgMode.FG_DAG)

t = pa.Table.from_arrays(
[
Expand All @@ -157,7 +157,7 @@ def test_csv_dataset_with_all_nulls(self):
data_config = data_pb2.DataConfig(
batch_size=4,
dataset_type=data_pb2.DatasetType.CsvDataset,
fg_encoded=False,
fg_mode=data_pb2.FgMode.FG_DAG,
label_fields=["label"],
)
data_config.input_fields.extend(
Expand Down
19 changes: 12 additions & 7 deletions tzrec/datasets/data_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,11 @@ def __init__(
if feature.is_weighted:
self.has_weight_keys[feature.data_group].append(feature.name)

self.feature_input_names = set()
if self._fg_mode == FgMode.DAG:
if self._fg_mode in [FgMode.FG_DAG, FgMode.FG_BUCKETIZE]:
self._init_fg_hander()

self.feature_input_names = set()
if self._fg_mode == FgMode.FG_DAG:
self.feature_input_names = (
self._fg_handler.user_inputs()
| self._fg_handler.item_inputs()
Expand All @@ -118,8 +120,11 @@ def _init_fg_hander(self) -> None:
"""Init pyfg dag handler."""
if not self._fg_handler:
fg_json = create_fg_json(self._features)
bucketize_only = self._fg_mode == FgMode.FG_BUCKETIZE
# pyre-ignore [16]
self._fg_handler = pyfg.FgArrowHandler(fg_json, self._fg_threads)
self._fg_handler = pyfg.FgArrowHandler(
fg_json, self._fg_threads, bucketize_only=bucketize_only
)

def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]:
"""Parse input data dict and build batch.
Expand All @@ -134,7 +139,7 @@ def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]:
if is_input_tile():
flag = False
for k, v in input_data.items():
if self._fg_mode == FgMode.ENCODED:
if self._fg_mode == FgMode.FG_NONE:
if k in self.user_feats:
input_data[k] = v.take([0])
else:
Expand All @@ -144,8 +149,8 @@ def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]:
output_data["batch_size"] = torch.tensor(v.__len__())
flag = True

if self._fg_mode == FgMode.DAG:
self._parse_feature_fg_dag(input_data, output_data)
if self._fg_mode in (FgMode.FG_DAG, FgMode.FG_BUCKETIZE):
self._parse_feature_fg_handler(input_data, output_data)
else:
self._parse_feature_normal(input_data, output_data)

Expand Down Expand Up @@ -228,7 +233,7 @@ def _parse_feature_normal(
)
output_data[f"{feature.name}.values"] = _to_tensor(feat_data.values)

def _parse_feature_fg_dag(
def _parse_feature_fg_handler(
self, input_data: Dict[str, pa.Array], output_data: Dict[str, torch.Tensor]
) -> None:
max_batch_size = (
Expand Down
74 changes: 65 additions & 9 deletions tzrec/datasets/data_parser_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ class DataParserTest(unittest.TestCase):
def tearDown(self):
os.environ.pop("INPUT_TILE", None)

def test_fg_encoded(self):
def test_fg_none(self):
feature_cfgs = [
feature_pb2.FeatureConfig(
id_feature=feature_pb2.IdFeature(
Expand Down Expand Up @@ -404,10 +404,10 @@ def _create_test_fg_feature_cfgs(self, tag_b_weighted=False):

@parameterized.expand(
[
[FgMode.NORMAL, False],
[FgMode.DAG, False],
[FgMode.NORMAL, True],
[FgMode.DAG, True],
[FgMode.FG_NORMAL, False],
[FgMode.FG_DAG, False],
[FgMode.FG_NORMAL, True],
[FgMode.FG_DAG, True],
]
)
def test_fg(self, fg_mode, weigted_id):
Expand Down Expand Up @@ -527,6 +527,62 @@ def test_fg(self, fg_mode, weigted_id):
)
torch.testing.assert_close(batch.labels["label"], expected_label)

def test_fg_bucketize_only(self):
feature_cfgs = self._create_test_fg_feature_cfgs()
features = create_features(feature_cfgs, fg_mode=FgMode.FG_BUCKETIZE)
data_parser = DataParser(features=features, labels=["label"])
data = data_parser.parse(
input_data={
"cat_a": pa.array([["1"], ["2"], ["3"]]),
"tag_b": pa.array([["4", "5"], [], ["6"]]),
"int_a": pa.array([7, 8, 9], pa.float32()),
"int_b": pa.array([[27, 37], [28, 38], [29, 39]]),
"lookup_a": pa.array([0.1, 0.0, 0.2], type=pa.float32()),
"click_seq__cat_a": pa.array([["10", "11", "12"], ["13"], ["0"]]),
"click_seq__int_a": pa.array([["14", "15", "16"], ["17"], ["0"]]),
"label": pa.array([0, 0, 1], pa.int32()),
}
)

expected_cat_a_values = torch.tensor([1, 2, 3], dtype=torch.int64)
expected_cat_a_lengths = torch.tensor([1, 1, 1], dtype=torch.int32)
expected_tag_b_values = torch.tensor([4, 5, 6], dtype=torch.int64)
expected_tag_b_lengths = torch.tensor([2, 0, 1], dtype=torch.int32)
expected_int_a_values = torch.tensor([[7], [8], [9]], dtype=torch.float32)
expected_int_b_values = torch.tensor(
[[27, 37], [28, 38], [29, 39]], dtype=torch.float32
)
expected_lookup_a_values = torch.tensor(
[[0.1], [0.0], [0.2]], dtype=torch.float32
)
expected_seq_cat_a_values = torch.tensor([10, 11, 12, 13, 0], dtype=torch.int64)
expected_seq_cat_a_seq_lengths = torch.tensor([3, 1, 1], dtype=torch.int32)
expected_seq_int_a_values = torch.tensor(
[[14], [15], [16], [17], [0]], dtype=torch.float32
)
expected_seq_int_a_seq_lengths = torch.tensor([3, 1, 1], dtype=torch.int32)
expected_label = torch.tensor([0, 0, 1], dtype=torch.int64)
torch.testing.assert_close(data["cat_a.values"], expected_cat_a_values)
torch.testing.assert_close(data["cat_a.lengths"], expected_cat_a_lengths)
torch.testing.assert_close(data["tag_b.values"], expected_tag_b_values)
torch.testing.assert_close(data["tag_b.lengths"], expected_tag_b_lengths)
torch.testing.assert_close(data["int_a.values"], expected_int_a_values)
torch.testing.assert_close(data["int_b.values"], expected_int_b_values)
torch.testing.assert_close(data["lookup_a.values"], expected_lookup_a_values)
torch.testing.assert_close(
data["click_seq__cat_a.values"], expected_seq_cat_a_values
)
torch.testing.assert_close(
data["click_seq__cat_a.lengths"], expected_seq_cat_a_seq_lengths
)
torch.testing.assert_close(
data["click_seq__int_a.values"], expected_seq_int_a_values
)
torch.testing.assert_close(
data["click_seq__int_a.lengths"], expected_seq_int_a_seq_lengths
)
torch.testing.assert_close(data["label"], expected_label)

@parameterized.expand(
[
[
Expand All @@ -540,7 +596,7 @@ def test_fg(self, fg_mode, weigted_id):
"click_seq__int_a": pa.array(["14;15;16", "17", ""]),
"label": pa.array([0, 0, 1], pa.int32()),
},
FgMode.ENCODED,
FgMode.FG_NONE,
],
[
{
Expand All @@ -553,7 +609,7 @@ def test_fg(self, fg_mode, weigted_id):
"click_seq__int_a": pa.array(["14;15;16", "17", ""]),
"label": pa.array([0, 0, 1], pa.int32()),
},
FgMode.DAG,
FgMode.FG_DAG,
],
]
)
Expand Down Expand Up @@ -664,7 +720,7 @@ def test_input_tile(self, input_data, fg_mode):
"click_seq__int_a": pa.array(["14;15;16", "17", ""]),
"label": pa.array([0, 0, 1], pa.int32()),
},
FgMode.ENCODED,
FgMode.FG_NONE,
],
[
{
Expand All @@ -677,7 +733,7 @@ def test_input_tile(self, input_data, fg_mode):
"click_seq__int_a": pa.array(["14;15;16", "17", ""]),
"label": pa.array([0, 0, 1], pa.int32()),
},
FgMode.DAG,
FgMode.FG_DAG,
],
]
)
Expand Down
4 changes: 2 additions & 2 deletions tzrec/datasets/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,7 +205,7 @@ def __init__(
):
self._selected_input_names = None

self._fg_encoded = data_config.fg_encoded
self._fg_mode = data_config.fg_mode
self._fg_encoded_multival_sep = data_config.fg_encoded_multival_sep

if mode != Mode.TRAIN and data_config.HasField("eval_batch_size"):
Expand Down Expand Up @@ -235,7 +235,7 @@ def launch_sampler_cluster(
self._batch_size,
is_training=self._mode == Mode.TRAIN,
multival_sep=self._fg_encoded_multival_sep
if self._fg_encoded
if self._fg_mode == data_pb2.FgMode.FG_NONE
else chr(29),
)
self._sampler.init_cluster(num_client_per_rank, client_id_bias, cluster)
Expand Down
12 changes: 6 additions & 6 deletions tzrec/datasets/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def test_dataset(self):
data_config=data_pb2.DataConfig(
batch_size=4,
dataset_type=data_pb2.DatasetType.OdpsDataset,
fg_encoded=True,
fg_mode=data_pb2.FgModel.FG_NONE,
label_fields=["label"],
),
features=features,
Expand Down Expand Up @@ -189,7 +189,7 @@ def test_dataset_with_sampler(self, force_base_data_group):
data_config=data_pb2.DataConfig(
batch_size=4,
dataset_type=data_pb2.DatasetType.OdpsDataset,
fg_encoded=True,
fg_mode=data_pb2.FgModel.FG_NONE,
label_fields=["label"],
negative_sampler=sampler_pb2.NegativeSampler(
input_path=f.name,
Expand Down Expand Up @@ -284,7 +284,7 @@ def test_dataset_with_sample_mask(self):
data_config=data_pb2.DataConfig(
batch_size=32,
dataset_type=data_pb2.DatasetType.OdpsDataset,
fg_encoded=True,
fg_mode=data_pb2.FgModel.FG_NONE,
label_fields=["label"],
sample_mask_prob=0.4,
),
Expand Down Expand Up @@ -348,7 +348,7 @@ def test_dataset_with_neg_sample_mask(self):
data_config=data_pb2.DataConfig(
batch_size=32,
dataset_type=data_pb2.DatasetType.OdpsDataset,
fg_encoded=True,
fg_mode=data_pb2.FgModel.FG_NONE,
label_fields=["label"],
negative_sample_mask_prob=0.4,
negative_sampler=sampler_pb2.NegativeSampler(
Expand Down Expand Up @@ -408,7 +408,7 @@ def test_dataset_predict_mode(self, debug_level):
data_config=data_pb2.DataConfig(
batch_size=4,
dataset_type=data_pb2.DatasetType.OdpsDataset,
fg_encoded=True,
fg_mode=data_pb2.FgModel.FG_NONE,
label_fields=[],
),
features=features,
Expand Down Expand Up @@ -501,7 +501,7 @@ def _childern(code):
data_config=data_pb2.DataConfig(
batch_size=4,
dataset_type=data_pb2.DatasetType.OdpsDataset,
fg_encoded=True,
fg_mode=data_pb2.FgModel.FG_NONE,
label_fields=["label"],
tdm_sampler=sampler_pb2.TDMSampler(
item_input_path=node.name,
Expand Down
Loading

0 comments on commit 4d25a8f

Please sign in to comment.