From 4d25a8f422736852da5cd61c09b23fd85c8cb835 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=A4=A9=E9=82=91?= Date: Fri, 13 Dec 2024 11:27:17 +0800 Subject: [PATCH] add bucketize only mode & refactor fg_encoded to fg_mode --- tzrec/benchmark/configs/criteo/deepfm.config | 1 - tzrec/benchmark/configs/taobao/dbmtl.config | 1 - .../configs/taobao/dbmtl_has_sequence.config | 1 - .../benchmark/configs/taobao/dbmtl_jrc.config | 1 - tzrec/benchmark/configs/taobao/mmoe.config | 1 - .../configs/taobao/mmoe_has_sequence.config | 1 - tzrec/benchmark/configs/taobao/ple.config | 1 - .../configs/taobao/ple_has_sequence.config | 1 - .../benchmark/configs/taobao_ccp/dbmtl.config | 2 +- .../benchmark/configs/taobao_ccp/mmoe.config | 2 +- tzrec/benchmark/configs/taobao_ccp/ple.config | 2 +- tzrec/datasets/csv_dataset_test.py | 6 +- tzrec/datasets/data_parser.py | 19 +++-- tzrec/datasets/data_parser_test.py | 74 ++++++++++++++++--- tzrec/datasets/dataset.py | 4 +- tzrec/datasets/dataset_test.py | 12 +-- tzrec/datasets/odps_dataset_test.py | 8 +- tzrec/datasets/odps_dataset_v1_test.py | 2 +- tzrec/datasets/parquet_dataset_test.py | 2 +- tzrec/features/combo_feature.py | 12 ++- tzrec/features/combo_feature_test.py | 4 +- tzrec/features/expr_feature.py | 12 ++- tzrec/features/expr_feature_test.py | 12 ++- tzrec/features/feature.py | 26 ++----- tzrec/features/feature_test.py | 4 +- tzrec/features/id_feature.py | 14 ++-- tzrec/features/id_feature_test.py | 10 +-- tzrec/features/lookup_feature.py | 12 ++- tzrec/features/lookup_feature_test.py | 8 +- tzrec/features/match_feature.py | 12 ++- tzrec/features/match_feature_test.py | 8 +- tzrec/features/overlap_feature.py | 12 ++- tzrec/features/overlap_feature_test.py | 4 +- tzrec/features/raw_feature.py | 12 ++- tzrec/features/raw_feature_test.py | 4 +- tzrec/features/sequence_feature.py | 24 ++++-- tzrec/features/sequence_feature_test.py | 14 ++-- tzrec/features/tokenize_feature.py | 12 ++- tzrec/features/tokenize_feature_test.py | 4 +- tzrec/main.py | 9 +-- tzrec/predict.py | 2 +- tzrec/protos/data.proto | 19 +++++ tzrec/protos/feature.proto | 20 ++--- .../configs/dbmtl_has_sequence_mock.config | 1 - ...s_sequence_variational_dropout_mock.config | 1 - tzrec/tests/configs/dssm_fg_mock.config | 2 +- tzrec/tests/configs/dssm_mock.config | 1 - tzrec/tests/configs/dssm_v2_fg_mock.config | 2 +- .../dssm_variational_dropout_mock.config | 1 - .../configs/multi_tower_din_fg_mock.config | 2 +- .../tests/configs/multi_tower_din_mock.config | 1 - .../multi_tower_din_trt_fg_mock.config | 2 +- .../multi_tower_din_zch_fg_mock.config | 2 +- .../multi_tower_din_zch_trt_fg_mock.config | 2 +- tzrec/tests/configs/tdm_fg_mock.config | 2 +- tzrec/tests/match_integration_test.py | 4 +- tzrec/tests/rank_integration_test.py | 8 +- tzrec/tests/utils.py | 8 +- .../convert_easyrec_config_to_tzrec_config.py | 1 - ...ert_easyrec_config_to_tzrec_config_test.py | 1 - tzrec/utils/config_util.py | 22 ++++++ 61 files changed, 293 insertions(+), 181 deletions(-) diff --git a/tzrec/benchmark/configs/criteo/deepfm.config b/tzrec/benchmark/configs/criteo/deepfm.config index 066c0a1..26e78d3 100644 --- a/tzrec/benchmark/configs/criteo/deepfm.config +++ b/tzrec/benchmark/configs/criteo/deepfm.config @@ -23,7 +23,6 @@ eval_config {} data_config { batch_size: 8192 dataset_type: OdpsDataset - fg_encoded: true label_fields: "label" num_workers: 8 odps_data_quota_name: "" diff --git a/tzrec/benchmark/configs/taobao/dbmtl.config b/tzrec/benchmark/configs/taobao/dbmtl.config index 44720d0..6578b7b 100644 --- a/tzrec/benchmark/configs/taobao/dbmtl.config +++ b/tzrec/benchmark/configs/taobao/dbmtl.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: OdpsDataset - fg_encoded: true label_fields: "clk" label_fields: "buy" num_workers: 8 diff --git a/tzrec/benchmark/configs/taobao/dbmtl_has_sequence.config b/tzrec/benchmark/configs/taobao/dbmtl_has_sequence.config index e7385c2..327d38c 100644 --- a/tzrec/benchmark/configs/taobao/dbmtl_has_sequence.config +++ b/tzrec/benchmark/configs/taobao/dbmtl_has_sequence.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: OdpsDataset - fg_encoded: true label_fields: "clk" label_fields: "buy" num_workers: 8 diff --git a/tzrec/benchmark/configs/taobao/dbmtl_jrc.config b/tzrec/benchmark/configs/taobao/dbmtl_jrc.config index d77509c..420938a 100644 --- a/tzrec/benchmark/configs/taobao/dbmtl_jrc.config +++ b/tzrec/benchmark/configs/taobao/dbmtl_jrc.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: OdpsDataset - fg_encoded: true label_fields: "clk" label_fields: "buy" num_workers: 8 diff --git a/tzrec/benchmark/configs/taobao/mmoe.config b/tzrec/benchmark/configs/taobao/mmoe.config index 952f267..b4fd4c3 100644 --- a/tzrec/benchmark/configs/taobao/mmoe.config +++ b/tzrec/benchmark/configs/taobao/mmoe.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: OdpsDataset - fg_encoded: true label_fields: "clk" label_fields: "buy" num_workers: 8 diff --git a/tzrec/benchmark/configs/taobao/mmoe_has_sequence.config b/tzrec/benchmark/configs/taobao/mmoe_has_sequence.config index aef7f65..3bee2fe 100644 --- a/tzrec/benchmark/configs/taobao/mmoe_has_sequence.config +++ b/tzrec/benchmark/configs/taobao/mmoe_has_sequence.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: OdpsDataset - fg_encoded: true label_fields: "clk" label_fields: "buy" num_workers: 8 diff --git a/tzrec/benchmark/configs/taobao/ple.config b/tzrec/benchmark/configs/taobao/ple.config index ca2b7b2..c75d8b6 100644 --- a/tzrec/benchmark/configs/taobao/ple.config +++ b/tzrec/benchmark/configs/taobao/ple.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: OdpsDataset - fg_encoded: true label_fields: "clk" label_fields: "buy" num_workers: 8 diff --git a/tzrec/benchmark/configs/taobao/ple_has_sequence.config b/tzrec/benchmark/configs/taobao/ple_has_sequence.config index c806181..324223f 100644 --- a/tzrec/benchmark/configs/taobao/ple_has_sequence.config +++ b/tzrec/benchmark/configs/taobao/ple_has_sequence.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: OdpsDataset - fg_encoded: true label_fields: "clk" label_fields: "buy" num_workers: 8 diff --git a/tzrec/benchmark/configs/taobao_ccp/dbmtl.config b/tzrec/benchmark/configs/taobao_ccp/dbmtl.config index 449ca41..bf2567a 100644 --- a/tzrec/benchmark/configs/taobao_ccp/dbmtl.config +++ b/tzrec/benchmark/configs/taobao_ccp/dbmtl.config @@ -23,7 +23,7 @@ eval_config { data_config { batch_size: 8192 dataset_type: OdpsDataset - fg_encoded: false + fg_mode: FG_DAG label_fields: "click_label" label_fields: "conversion_label" num_workers: 8 diff --git a/tzrec/benchmark/configs/taobao_ccp/mmoe.config b/tzrec/benchmark/configs/taobao_ccp/mmoe.config index 49475bb..f5ccea2 100644 --- a/tzrec/benchmark/configs/taobao_ccp/mmoe.config +++ b/tzrec/benchmark/configs/taobao_ccp/mmoe.config @@ -23,7 +23,7 @@ eval_config { data_config { batch_size: 8192 dataset_type: OdpsDataset - fg_encoded: false + fg_mode: FG_DAG label_fields: "click_label" label_fields: "conversion_label" num_workers: 8 diff --git a/tzrec/benchmark/configs/taobao_ccp/ple.config b/tzrec/benchmark/configs/taobao_ccp/ple.config index 29d57a0..f717ebf 100644 --- a/tzrec/benchmark/configs/taobao_ccp/ple.config +++ b/tzrec/benchmark/configs/taobao_ccp/ple.config @@ -23,7 +23,7 @@ eval_config { data_config { batch_size: 8192 dataset_type: OdpsDataset - fg_encoded: false + fg_mode: FG_DAG label_fields: "click_label" label_fields: "conversion_label" num_workers: 8 diff --git a/tzrec/datasets/csv_dataset_test.py b/tzrec/datasets/csv_dataset_test.py index b9ade69..e13174d 100644 --- a/tzrec/datasets/csv_dataset_test.py +++ b/tzrec/datasets/csv_dataset_test.py @@ -67,7 +67,7 @@ def test_csv_dataset(self, with_header, num_rows): data_config = data_pb2.DataConfig( batch_size=4, dataset_type=data_pb2.DatasetType.CsvDataset, - fg_encoded=True, + fg_mode=data_pb2.FgMode.FG_NONE, label_fields=["label"], with_header=with_header, ) @@ -134,7 +134,7 @@ def test_csv_dataset_with_all_nulls(self): ) ), ] - features = create_features(feature_cfgs, fg_mode=FgMode.DAG) + features = create_features(feature_cfgs, fg_mode=FgMode.FG_DAG) t = pa.Table.from_arrays( [ @@ -157,7 +157,7 @@ def test_csv_dataset_with_all_nulls(self): data_config = data_pb2.DataConfig( batch_size=4, dataset_type=data_pb2.DatasetType.CsvDataset, - fg_encoded=False, + fg_mode=data_pb2.FgMode.FG_DAG, label_fields=["label"], ) data_config.input_fields.extend( diff --git a/tzrec/datasets/data_parser.py b/tzrec/datasets/data_parser.py index b7c38f9..3fedfa4 100644 --- a/tzrec/datasets/data_parser.py +++ b/tzrec/datasets/data_parser.py @@ -89,9 +89,11 @@ def __init__( if feature.is_weighted: self.has_weight_keys[feature.data_group].append(feature.name) - self.feature_input_names = set() - if self._fg_mode == FgMode.DAG: + if self._fg_mode in [FgMode.FG_DAG, FgMode.FG_BUCKETIZE]: self._init_fg_hander() + + self.feature_input_names = set() + if self._fg_mode == FgMode.FG_DAG: self.feature_input_names = ( self._fg_handler.user_inputs() | self._fg_handler.item_inputs() @@ -118,8 +120,11 @@ def _init_fg_hander(self) -> None: """Init pyfg dag handler.""" if not self._fg_handler: fg_json = create_fg_json(self._features) + bucketize_only = self._fg_mode == FgMode.FG_BUCKETIZE # pyre-ignore [16] - self._fg_handler = pyfg.FgArrowHandler(fg_json, self._fg_threads) + self._fg_handler = pyfg.FgArrowHandler( + fg_json, self._fg_threads, bucketize_only=bucketize_only + ) def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]: """Parse input data dict and build batch. @@ -134,7 +139,7 @@ def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]: if is_input_tile(): flag = False for k, v in input_data.items(): - if self._fg_mode == FgMode.ENCODED: + if self._fg_mode == FgMode.FG_NONE: if k in self.user_feats: input_data[k] = v.take([0]) else: @@ -144,8 +149,8 @@ def parse(self, input_data: Dict[str, pa.Array]) -> Dict[str, torch.Tensor]: output_data["batch_size"] = torch.tensor(v.__len__()) flag = True - if self._fg_mode == FgMode.DAG: - self._parse_feature_fg_dag(input_data, output_data) + if self._fg_mode in (FgMode.FG_DAG, FgMode.FG_BUCKETIZE): + self._parse_feature_fg_handler(input_data, output_data) else: self._parse_feature_normal(input_data, output_data) @@ -228,7 +233,7 @@ def _parse_feature_normal( ) output_data[f"{feature.name}.values"] = _to_tensor(feat_data.values) - def _parse_feature_fg_dag( + def _parse_feature_fg_handler( self, input_data: Dict[str, pa.Array], output_data: Dict[str, torch.Tensor] ) -> None: max_batch_size = ( diff --git a/tzrec/datasets/data_parser_test.py b/tzrec/datasets/data_parser_test.py index 83ca1dd..a91863c 100644 --- a/tzrec/datasets/data_parser_test.py +++ b/tzrec/datasets/data_parser_test.py @@ -33,7 +33,7 @@ class DataParserTest(unittest.TestCase): def tearDown(self): os.environ.pop("INPUT_TILE", None) - def test_fg_encoded(self): + def test_fg_none(self): feature_cfgs = [ feature_pb2.FeatureConfig( id_feature=feature_pb2.IdFeature( @@ -404,10 +404,10 @@ def _create_test_fg_feature_cfgs(self, tag_b_weighted=False): @parameterized.expand( [ - [FgMode.NORMAL, False], - [FgMode.DAG, False], - [FgMode.NORMAL, True], - [FgMode.DAG, True], + [FgMode.FG_NORMAL, False], + [FgMode.FG_DAG, False], + [FgMode.FG_NORMAL, True], + [FgMode.FG_DAG, True], ] ) def test_fg(self, fg_mode, weigted_id): @@ -527,6 +527,62 @@ def test_fg(self, fg_mode, weigted_id): ) torch.testing.assert_close(batch.labels["label"], expected_label) + def test_fg_bucketize_only(self): + feature_cfgs = self._create_test_fg_feature_cfgs() + features = create_features(feature_cfgs, fg_mode=FgMode.FG_BUCKETIZE) + data_parser = DataParser(features=features, labels=["label"]) + data = data_parser.parse( + input_data={ + "cat_a": pa.array([["1"], ["2"], ["3"]]), + "tag_b": pa.array([["4", "5"], [], ["6"]]), + "int_a": pa.array([7, 8, 9], pa.float32()), + "int_b": pa.array([[27, 37], [28, 38], [29, 39]]), + "lookup_a": pa.array([0.1, 0.0, 0.2], type=pa.float32()), + "click_seq__cat_a": pa.array([["10", "11", "12"], ["13"], ["0"]]), + "click_seq__int_a": pa.array([["14", "15", "16"], ["17"], ["0"]]), + "label": pa.array([0, 0, 1], pa.int32()), + } + ) + + expected_cat_a_values = torch.tensor([1, 2, 3], dtype=torch.int64) + expected_cat_a_lengths = torch.tensor([1, 1, 1], dtype=torch.int32) + expected_tag_b_values = torch.tensor([4, 5, 6], dtype=torch.int64) + expected_tag_b_lengths = torch.tensor([2, 0, 1], dtype=torch.int32) + expected_int_a_values = torch.tensor([[7], [8], [9]], dtype=torch.float32) + expected_int_b_values = torch.tensor( + [[27, 37], [28, 38], [29, 39]], dtype=torch.float32 + ) + expected_lookup_a_values = torch.tensor( + [[0.1], [0.0], [0.2]], dtype=torch.float32 + ) + expected_seq_cat_a_values = torch.tensor([10, 11, 12, 13, 0], dtype=torch.int64) + expected_seq_cat_a_seq_lengths = torch.tensor([3, 1, 1], dtype=torch.int32) + expected_seq_int_a_values = torch.tensor( + [[14], [15], [16], [17], [0]], dtype=torch.float32 + ) + expected_seq_int_a_seq_lengths = torch.tensor([3, 1, 1], dtype=torch.int32) + expected_label = torch.tensor([0, 0, 1], dtype=torch.int64) + torch.testing.assert_close(data["cat_a.values"], expected_cat_a_values) + torch.testing.assert_close(data["cat_a.lengths"], expected_cat_a_lengths) + torch.testing.assert_close(data["tag_b.values"], expected_tag_b_values) + torch.testing.assert_close(data["tag_b.lengths"], expected_tag_b_lengths) + torch.testing.assert_close(data["int_a.values"], expected_int_a_values) + torch.testing.assert_close(data["int_b.values"], expected_int_b_values) + torch.testing.assert_close(data["lookup_a.values"], expected_lookup_a_values) + torch.testing.assert_close( + data["click_seq__cat_a.values"], expected_seq_cat_a_values + ) + torch.testing.assert_close( + data["click_seq__cat_a.lengths"], expected_seq_cat_a_seq_lengths + ) + torch.testing.assert_close( + data["click_seq__int_a.values"], expected_seq_int_a_values + ) + torch.testing.assert_close( + data["click_seq__int_a.lengths"], expected_seq_int_a_seq_lengths + ) + torch.testing.assert_close(data["label"], expected_label) + @parameterized.expand( [ [ @@ -540,7 +596,7 @@ def test_fg(self, fg_mode, weigted_id): "click_seq__int_a": pa.array(["14;15;16", "17", ""]), "label": pa.array([0, 0, 1], pa.int32()), }, - FgMode.ENCODED, + FgMode.FG_NONE, ], [ { @@ -553,7 +609,7 @@ def test_fg(self, fg_mode, weigted_id): "click_seq__int_a": pa.array(["14;15;16", "17", ""]), "label": pa.array([0, 0, 1], pa.int32()), }, - FgMode.DAG, + FgMode.FG_DAG, ], ] ) @@ -664,7 +720,7 @@ def test_input_tile(self, input_data, fg_mode): "click_seq__int_a": pa.array(["14;15;16", "17", ""]), "label": pa.array([0, 0, 1], pa.int32()), }, - FgMode.ENCODED, + FgMode.FG_NONE, ], [ { @@ -677,7 +733,7 @@ def test_input_tile(self, input_data, fg_mode): "click_seq__int_a": pa.array(["14;15;16", "17", ""]), "label": pa.array([0, 0, 1], pa.int32()), }, - FgMode.DAG, + FgMode.FG_DAG, ], ] ) diff --git a/tzrec/datasets/dataset.py b/tzrec/datasets/dataset.py index e3e7275..f9233aa 100644 --- a/tzrec/datasets/dataset.py +++ b/tzrec/datasets/dataset.py @@ -205,7 +205,7 @@ def __init__( ): self._selected_input_names = None - self._fg_encoded = data_config.fg_encoded + self._fg_mode = data_config.fg_mode self._fg_encoded_multival_sep = data_config.fg_encoded_multival_sep if mode != Mode.TRAIN and data_config.HasField("eval_batch_size"): @@ -235,7 +235,7 @@ def launch_sampler_cluster( self._batch_size, is_training=self._mode == Mode.TRAIN, multival_sep=self._fg_encoded_multival_sep - if self._fg_encoded + if self._fg_mode == data_pb2.FgMode.FG_NONE else chr(29), ) self._sampler.init_cluster(num_client_per_rank, client_id_bias, cluster) diff --git a/tzrec/datasets/dataset_test.py b/tzrec/datasets/dataset_test.py index 77ba9fa..f9ebc06 100644 --- a/tzrec/datasets/dataset_test.py +++ b/tzrec/datasets/dataset_test.py @@ -122,7 +122,7 @@ def test_dataset(self): data_config=data_pb2.DataConfig( batch_size=4, dataset_type=data_pb2.DatasetType.OdpsDataset, - fg_encoded=True, + fg_mode=data_pb2.FgModel.FG_NONE, label_fields=["label"], ), features=features, @@ -189,7 +189,7 @@ def test_dataset_with_sampler(self, force_base_data_group): data_config=data_pb2.DataConfig( batch_size=4, dataset_type=data_pb2.DatasetType.OdpsDataset, - fg_encoded=True, + fg_mode=data_pb2.FgModel.FG_NONE, label_fields=["label"], negative_sampler=sampler_pb2.NegativeSampler( input_path=f.name, @@ -284,7 +284,7 @@ def test_dataset_with_sample_mask(self): data_config=data_pb2.DataConfig( batch_size=32, dataset_type=data_pb2.DatasetType.OdpsDataset, - fg_encoded=True, + fg_mode=data_pb2.FgModel.FG_NONE, label_fields=["label"], sample_mask_prob=0.4, ), @@ -348,7 +348,7 @@ def test_dataset_with_neg_sample_mask(self): data_config=data_pb2.DataConfig( batch_size=32, dataset_type=data_pb2.DatasetType.OdpsDataset, - fg_encoded=True, + fg_mode=data_pb2.FgModel.FG_NONE, label_fields=["label"], negative_sample_mask_prob=0.4, negative_sampler=sampler_pb2.NegativeSampler( @@ -408,7 +408,7 @@ def test_dataset_predict_mode(self, debug_level): data_config=data_pb2.DataConfig( batch_size=4, dataset_type=data_pb2.DatasetType.OdpsDataset, - fg_encoded=True, + fg_mode=data_pb2.FgModel.FG_NONE, label_fields=[], ), features=features, @@ -501,7 +501,7 @@ def _childern(code): data_config=data_pb2.DataConfig( batch_size=4, dataset_type=data_pb2.DatasetType.OdpsDataset, - fg_encoded=True, + fg_mode=data_pb2.FgModel.FG_NONE, label_fields=["label"], tdm_sampler=sampler_pb2.TDMSampler( item_input_path=node.name, diff --git a/tzrec/datasets/odps_dataset_test.py b/tzrec/datasets/odps_dataset_test.py index 101f1e1..f6cea1f 100644 --- a/tzrec/datasets/odps_dataset_test.py +++ b/tzrec/datasets/odps_dataset_test.py @@ -154,13 +154,13 @@ def _create_test_table_and_feature_cfgs(self, has_lookup=True): @unittest.skipIf("ODPS_CONFIG_FILE_PATH" not in os.environ, "odps config not found") def test_odps_dataset(self, is_orderby_partition): feature_cfgs = self._create_test_table_and_feature_cfgs() - features = create_features(feature_cfgs, fg_mode=FgMode.DAG) + features = create_features(feature_cfgs, fg_mode=FgMode.FG_DAG) dataset = OdpsDataset( data_config=data_pb2.DataConfig( batch_size=8196, dataset_type=data_pb2.DatasetType.OdpsDataset, - fg_encoded=False, + fg_mode=data_pb2.FgMode.FG_DAG, label_fields=["label"], is_orderby_partition=is_orderby_partition, odps_data_quota_name="", @@ -220,13 +220,13 @@ def test_odps_dataset_with_sampler(self): writer.write([[i, 1.0, f"{i}:4:5.0"] for i in range(10000)]) features = create_features( - feature_cfgs, fg_mode=FgMode.DAG, neg_fields=["id_a", "raw_c", "raw_d"] + feature_cfgs, fg_mode=FgMode.FG_DAG, neg_fields=["id_a", "raw_c", "raw_d"] ) dataset = OdpsDataset( data_config=data_pb2.DataConfig( batch_size=8196, dataset_type=data_pb2.DatasetType.OdpsDataset, - fg_encoded=False, + fg_mode=data_pb2.FgMode.FG_DAG, label_fields=["label"], odps_data_quota_name="", negative_sampler=sampler_pb2.NegativeSampler( diff --git a/tzrec/datasets/odps_dataset_v1_test.py b/tzrec/datasets/odps_dataset_v1_test.py index 17ce35c..8f284aa 100644 --- a/tzrec/datasets/odps_dataset_v1_test.py +++ b/tzrec/datasets/odps_dataset_v1_test.py @@ -80,7 +80,7 @@ def test_odps_dataset(self): data_config=data_pb2.DataConfig( batch_size=4, dataset_type=data_pb2.DatasetType.OdpsDataset, - fg_encoded=True, + fg_mode=data_pb2.FgMode.FG_NONE, label_fields=["label"], ), features=features, diff --git a/tzrec/datasets/parquet_dataset_test.py b/tzrec/datasets/parquet_dataset_test.py index 7cdd2b5..e1f6b8b 100644 --- a/tzrec/datasets/parquet_dataset_test.py +++ b/tzrec/datasets/parquet_dataset_test.py @@ -69,7 +69,7 @@ def test_parquet_dataset(self, num_rows): data_config=data_pb2.DataConfig( batch_size=4, dataset_type=data_pb2.DatasetType.ParquetDataset, - fg_encoded=True, + fg_mode=data_pb2.FgMode.FG_NONE, label_fields=["label"], ), features=features, diff --git a/tzrec/features/combo_feature.py b/tzrec/features/combo_feature.py index c608c1d..549a668 100644 --- a/tzrec/features/combo_feature.py +++ b/tzrec/features/combo_feature.py @@ -36,13 +36,13 @@ class ComboFeature(IdFeature): Args: feature_config (FeatureConfig): a instance of feature config. fg_mode (FgMode): input data fg mode. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE """ def __init__( self, feature_config: FeatureConfig, - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, fg_encoded_multival_sep: Optional[str] = None, ) -> None: super().__init__(feature_config, fg_mode, fg_encoded_multival_sep) @@ -88,13 +88,13 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: Return: parsed feature data. """ - if self.fg_encoded: + if self.fg_mode == FgMode.FG_NONE: # input feature is already bucktized feat = input_data[self.name] parsed_feat = _parse_fg_encoded_sparse_feature_impl( self.name, feat, **self._fg_encoded_kwargs ) - else: + elif self.fg_mode == FgMode.FG_NORMAL: input_feats = [] for name in self.inputs: x = input_data[name] @@ -103,6 +103,10 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: input_feats.append(x.tolist()) values, lengths = self._fg_op.to_bucketized_jagged_tensor(input_feats) parsed_feat = SparseData(name=self.name, values=values, lengths=lengths) + else: + raise ValueError( + "fg_mode: {self.fg_mode} is not supported without fg handler." + ) return parsed_feat def _build_side_inputs(self) -> List[Tuple[str, str]]: diff --git a/tzrec/features/combo_feature_test.py b/tzrec/features/combo_feature_test.py index 483eecf..903af2d 100644 --- a/tzrec/features/combo_feature_test.py +++ b/tzrec/features/combo_feature_test.py @@ -108,7 +108,7 @@ def test_combo_feature_with_hash_bucket_size( ) ) combo_feat = combo_feature_lib.ComboFeature( - combo_feat_cfg, fg_mode=FgMode.NORMAL + combo_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(combo_feat.output_dim, 16) self.assertEqual(combo_feat.is_sparse, True) @@ -162,7 +162,7 @@ def test_combo_feature_with_vocab_list( ) ) combo_feat = combo_feature_lib.ComboFeature( - combo_feat_cfg, fg_mode=FgMode.NORMAL + combo_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(combo_feat.output_dim, 16) self.assertEqual(combo_feat.is_sparse, True) diff --git a/tzrec/features/expr_feature.py b/tzrec/features/expr_feature.py index 6233bf6..39009e5 100644 --- a/tzrec/features/expr_feature.py +++ b/tzrec/features/expr_feature.py @@ -34,13 +34,13 @@ class ExprFeature(RawFeature): Args: feature_config (FeatureConfig): a instance of feature config. fg_mode (FgMode): input data fg mode. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE """ def __init__( self, feature_config: FeatureConfig, - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, fg_encoded_multival_sep: Optional[str] = None, ) -> None: super().__init__(feature_config, fg_mode, fg_encoded_multival_sep) @@ -73,7 +73,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: Return: parsed feature data. """ - if self.fg_encoded: + if self.fg_mode == FgMode.FG_NONE: # input feature is already lookuped feat = input_data[self.name] if self.is_sparse: @@ -84,7 +84,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: parsed_feat = _parse_fg_encoded_dense_feature_impl( self.name, feat, **self._fg_encoded_kwargs ) - else: + elif self.fg_mode == FgMode.FG_NORMAL: input_feats = [input_data[x].tolist() for x in self.inputs] if self.is_sparse: values, lengths = self._fg_op.to_bucketized_jagged_tensor(input_feats) @@ -92,6 +92,10 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: else: values = self._fg_op.transform(input_feats) parsed_feat = DenseData(name=self.name, values=values) + else: + raise ValueError( + f"fg_mode: {self.fg_mode} is not supported without fg handler." + ) return parsed_feat def fg_json(self) -> List[Dict[str, Any]]: diff --git a/tzrec/features/expr_feature_test.py b/tzrec/features/expr_feature_test.py index 556a38d..56dabd7 100644 --- a/tzrec/features/expr_feature_test.py +++ b/tzrec/features/expr_feature_test.py @@ -110,7 +110,9 @@ def test_expr_feature_dense( default_value=default_value, ) ) - expr_feat = expr_feature_lib.ExprFeature(expr_feat_cfg, fg_mode=FgMode.NORMAL) + expr_feat = expr_feature_lib.ExprFeature( + expr_feat_cfg, fg_mode=FgMode.FG_NORMAL + ) self.assertEqual(expr_feat.output_dim, 1) self.assertEqual(expr_feat.is_sparse, False) self.assertEqual(expr_feat.inputs, ["a", "b"]) @@ -154,7 +156,9 @@ def test_expr_feature_with_boundaries( default_value=default_value, ) ) - expr_feat = expr_feature_lib.ExprFeature(expr_feat_cfg, fg_mode=FgMode.NORMAL) + expr_feat = expr_feature_lib.ExprFeature( + expr_feat_cfg, fg_mode=FgMode.FG_NORMAL + ) self.assertEqual(expr_feat.output_dim, 16) self.assertEqual(expr_feat.is_sparse, True) self.assertEqual(expr_feat.inputs, ["a", "b"]) @@ -216,7 +220,9 @@ def test_expr_feature_dot( default_value=default_value, ) ) - expr_feat = expr_feature_lib.ExprFeature(expr_feat_cfg, fg_mode=FgMode.NORMAL) + expr_feat = expr_feature_lib.ExprFeature( + expr_feat_cfg, fg_mode=FgMode.FG_NORMAL + ) self.assertEqual(expr_feat.output_dim, 16) self.assertEqual(expr_feat.is_sparse, True) self.assertEqual(expr_feat.inputs, ["a", "b"]) diff --git a/tzrec/features/feature.py b/tzrec/features/feature.py index 3a895c3..b4391ef 100644 --- a/tzrec/features/feature.py +++ b/tzrec/features/feature.py @@ -13,7 +13,6 @@ import shutil from collections import OrderedDict from copy import copy -from enum import Enum from functools import partial # NOQA from typing import Any, Dict, List, Optional, Tuple, Union @@ -47,6 +46,7 @@ ParsedData, SparseData, ) +from tzrec.protos.data_pb2 import FgMode from tzrec.protos.feature_pb2 import FeatureConfig, SequenceFeature from tzrec.utils import config_util from tzrec.utils.load_class import get_register_class_meta @@ -54,15 +54,6 @@ _FEATURE_CLASS_MAP = {} _meta_cls = get_register_class_meta(_FEATURE_CLASS_MAP) - -class FgMode(Enum): - """ENCODED/NORMAL/DAG Mode.""" - - ENCODED = 1 - NORMAL = 2 - DAG = 3 - - MAX_HASH_BUCKET_SIZE = 2**63 - 1 @@ -211,13 +202,13 @@ class BaseFeature(object, metaclass=_meta_cls): Args: feature_config (FeatureConfig): a instance of feature config. fg_mode (FgMode): input data fg mode. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE """ def __init__( self, feature_config: FeatureConfig, - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, fg_encoded_multival_sep: Optional[str] = None, ) -> None: fc_type = feature_config.WhichOneof("feature") @@ -225,7 +216,6 @@ def __init__( self.config = getattr(self._feature_config, fc_type) self.fg_mode = fg_mode - self.fg_encoded = fg_mode == FgMode.ENCODED self._fg_op = None self._is_neg = False @@ -238,7 +228,7 @@ def __init__( self._fg_encoded_kwargs = {} self._fg_encoded_multival_sep = fg_encoded_multival_sep or chr(3) - if self.fg_mode == FgMode.ENCODED: + if self.fg_mode == FgMode.FG_NONE: if self.config.HasField("fg_encoded_default_value"): self._fg_encoded_kwargs["default_value"] = ( self.fg_encoded_default_value() @@ -255,7 +245,7 @@ def __init__( ) from None self._fg_encoded_kwargs["multival_sep"] = self._fg_encoded_multival_sep - if self.fg_mode == FgMode.NORMAL: + if self.fg_mode == FgMode.FG_NORMAL: self.init_fg() @property @@ -439,7 +429,7 @@ def mc_module(self, device: torch.device) -> Optional[ManagedCollisionModule]: def inputs(self) -> List[str]: """Input field names.""" if not self._inputs: - if self.fg_encoded: + if self.fg_mode in [FgMode.FG_NONE, FgMode.FG_BUCKETIZE]: self._inputs = [self.name] else: self._inputs = [v for _, v in self.side_inputs] @@ -579,7 +569,7 @@ def __del__(self) -> None: def create_features( feature_configs: List[FeatureConfig], - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, neg_fields: Optional[List[str]] = None, fg_encoded_multival_sep: Optional[str] = None, force_base_data_group: bool = False, @@ -590,7 +580,7 @@ def create_features( feature_configs (list): list of feature_config. fg_mode (FgMode): input data fg mode. neg_fields (list, optional): negative sampled input fields. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE force_base_data_group (bool): force padding data into same data group with same batch_size. diff --git a/tzrec/features/feature_test.py b/tzrec/features/feature_test.py index 42aef3c..a98d971 100644 --- a/tzrec/features/feature_test.py +++ b/tzrec/features/feature_test.py @@ -310,7 +310,7 @@ def test_create_fg_json(self, with_asset_dir=False): asset_dir = self.test_dir token_file = "token_g_tokenizer.json" feature_cfgs = self._create_test_feature_cfgs() - features = feature_lib.create_features(feature_cfgs, fg_mode=FgMode.DAG) + features = feature_lib.create_features(feature_cfgs, fg_mode=FgMode.FG_DAG) fg_json = feature_lib.create_fg_json(features, asset_dir=asset_dir) self.maxDiff = None self.assertEqual( @@ -438,7 +438,7 @@ def test_create_fg_json(self, with_asset_dir=False): @parameterized.expand([[False], [True]]) def test_create_feauture_configs(self, with_asset_dir=False): feature_cfgs = self._create_test_feature_cfgs() - features = feature_lib.create_features(feature_cfgs, fg_mode=FgMode.DAG) + features = feature_lib.create_features(feature_cfgs, fg_mode=FgMode.FG_DAG) asset_dir = None token_file = "data/test/tokenizer.json" diff --git a/tzrec/features/id_feature.py b/tzrec/features/id_feature.py index 5102dd5..6981bc5 100644 --- a/tzrec/features/id_feature.py +++ b/tzrec/features/id_feature.py @@ -36,13 +36,13 @@ class IdFeature(BaseFeature): Args: feature_config (FeatureConfig): a instance of feature config. fg_mode (FgMode): input data fg mode. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE """ def __init__( self, feature_config: FeatureConfig, - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, fg_encoded_multival_sep: Optional[str] = None, ) -> None: super().__init__(feature_config, fg_mode, fg_encoded_multival_sep) @@ -100,7 +100,7 @@ def num_embeddings(self) -> int: def inputs(self) -> List[str]: """Input field names.""" if not self._inputs: - if self.fg_encoded: + if self.fg_mode == FgMode.FG_NONE: if self.is_weighted: self._inputs = [f"{self.name}__values", f"{self.name}__weights"] else: @@ -122,7 +122,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: Return: parsed feature data. """ - if self.fg_encoded: + if self.fg_mode == FgMode.FG_NONE: feat = input_data[self.inputs[0]] weight = None if len(self.inputs) == 2: @@ -130,7 +130,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: parsed_feat = _parse_fg_encoded_sparse_feature_impl( self.name, feat, weight=weight, **self._fg_encoded_kwargs ) - else: + elif self.fg_mode == FgMode.FG_NORMAL: input_feat = input_data[self.inputs[0]] if pa.types.is_list(input_feat.type): input_feat = input_feat.fill_null([]) @@ -145,6 +145,10 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: parsed_feat = SparseData( name=self.name, values=values, lengths=lengths, weights=weights ) + else: + raise ValueError( + f"fg_mode: {self.fg_mode} is not supported without fg handler." + ) return parsed_feat def fg_json(self) -> List[Dict[str, Any]]: diff --git a/tzrec/features/id_feature_test.py b/tzrec/features/id_feature_test.py index 389ce9c..15e88cd 100644 --- a/tzrec/features/id_feature_test.py +++ b/tzrec/features/id_feature_test.py @@ -194,7 +194,7 @@ def test_id_feature_with_weighted(self): weighted=True, ) ) - id_feat = id_feature_lib.IdFeature(id_feat_cfg, fg_mode=FgMode.NORMAL) + id_feat = id_feature_lib.IdFeature(id_feat_cfg, fg_mode=FgMode.FG_NORMAL) self.assertEqual(id_feat.inputs, ["cate"]) input_data = { @@ -235,7 +235,7 @@ def test_id_feature_with_hash_bucket_size( default_value=default_value, ) ) - id_feat = id_feature_lib.IdFeature(id_feat_cfg, fg_mode=FgMode.NORMAL) + id_feat = id_feature_lib.IdFeature(id_feat_cfg, fg_mode=FgMode.FG_NORMAL) self.assertEqual(id_feat.inputs, ["id_input"]) expected_emb_bag_config = EmbeddingBagConfig( @@ -285,7 +285,7 @@ def test_id_feature_with_vocab_list( default_value=default_value, ) ) - id_feat = id_feature_lib.IdFeature(id_feat_cfg, fg_mode=FgMode.NORMAL) + id_feat = id_feature_lib.IdFeature(id_feat_cfg, fg_mode=FgMode.FG_NORMAL) expected_emb_bag_config = EmbeddingBagConfig( num_embeddings=4, @@ -329,7 +329,7 @@ def test_id_feature_with_vocab_dict( default_value=default_value, ) ) - id_feat = id_feature_lib.IdFeature(id_feat_cfg, fg_mode=FgMode.NORMAL) + id_feat = id_feature_lib.IdFeature(id_feat_cfg, fg_mode=FgMode.FG_NORMAL) expected_emb_bag_config = EmbeddingBagConfig( num_embeddings=3, @@ -369,7 +369,7 @@ def test_id_feature_with_num_buckets( default_value=default_value, ) ) - id_feat = id_feature_lib.IdFeature(id_feat_cfg, fg_mode=FgMode.NORMAL) + id_feat = id_feature_lib.IdFeature(id_feat_cfg, fg_mode=FgMode.FG_NORMAL) expected_emb_bag_config = EmbeddingBagConfig( num_embeddings=100, diff --git a/tzrec/features/lookup_feature.py b/tzrec/features/lookup_feature.py index 9d4aac9..c7be0ce 100644 --- a/tzrec/features/lookup_feature.py +++ b/tzrec/features/lookup_feature.py @@ -39,13 +39,13 @@ class LookupFeature(BaseFeature): Args: feature_config (FeatureConfig): a instance of feature config. fg_mode (FgMode): input data fg mode. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE """ def __init__( self, feature_config: FeatureConfig, - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, fg_encoded_multival_sep: Optional[str] = None, ) -> None: super().__init__(feature_config, fg_mode, fg_encoded_multival_sep) @@ -119,7 +119,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: Return: parsed feature data. """ - if self.fg_encoded: + if self.fg_mode == FgMode.FG_NONE: # input feature is already lookuped feat = input_data[self.name] if self.is_sparse: @@ -130,7 +130,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: parsed_feat = _parse_fg_encoded_dense_feature_impl( self.name, feat, **self._fg_encoded_kwargs ) - else: + elif self.fg_mode == FgMode.FG_NORMAL: input_feats = [] for name in self.inputs: x = input_data[name] @@ -162,6 +162,10 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: else: values = self._fg_op.transform(*input_feats) parsed_feat = DenseData(name=self.name, values=values) + else: + raise ValueError( + f"fg_mode: {self.fg_mode} is not supported without fg handler." + ) return parsed_feat def fg_json(self) -> List[Dict[str, Any]]: diff --git a/tzrec/features/lookup_feature_test.py b/tzrec/features/lookup_feature_test.py index 78c8fbf..a4c4618 100644 --- a/tzrec/features/lookup_feature_test.py +++ b/tzrec/features/lookup_feature_test.py @@ -211,7 +211,7 @@ def test_lookup_feature_dense( ) ) lookup_feat = lookup_feature_lib.LookupFeature( - lookup_feat_cfg, fg_mode=FgMode.NORMAL + lookup_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual( lookup_feat.fg_encoded_default_value(), expected_fg_encoded_default @@ -253,7 +253,7 @@ def test_lookup_feature_with_boundary( ) ) lookup_feat = lookup_feature_lib.LookupFeature( - lookup_feat_cfg, fg_mode=FgMode.NORMAL + lookup_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(lookup_feat.output_dim, 16) self.assertEqual(lookup_feat.is_sparse, True) @@ -310,7 +310,7 @@ def test_lookup_feature_with_num_buckets( ) ) lookup_feat = lookup_feature_lib.LookupFeature( - lookup_feat_cfg, fg_mode=FgMode.NORMAL + lookup_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(lookup_feat.output_dim, 16) self.assertEqual(lookup_feat.is_sparse, True) @@ -366,7 +366,7 @@ def test_lookup_feature_with_hash_bucket_size( ) ) lookup_feat = lookup_feature_lib.LookupFeature( - lookup_feat_cfg, fg_mode=FgMode.NORMAL + lookup_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(lookup_feat.output_dim, 16) self.assertEqual(lookup_feat.is_sparse, True) diff --git a/tzrec/features/match_feature.py b/tzrec/features/match_feature.py index 4fbc557..895a714 100644 --- a/tzrec/features/match_feature.py +++ b/tzrec/features/match_feature.py @@ -39,13 +39,13 @@ class MatchFeature(BaseFeature): Args: feature_config (FeatureConfig): a instance of feature config. fg_mode (FgMode): input data fg mode. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE """ def __init__( self, feature_config: FeatureConfig, - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, fg_encoded_multival_sep: Optional[str] = None, ) -> None: super().__init__(feature_config, fg_mode, fg_encoded_multival_sep) @@ -125,7 +125,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: Return: parsed feature data. """ - if self.fg_encoded: + if self.fg_mode == FgMode.FG_NONE: # input feature is already lookuped feat = input_data[self.name] if self.is_sparse: @@ -136,7 +136,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: parsed_feat = _parse_fg_encoded_dense_feature_impl( self.name, feat, **self._fg_encoded_kwargs ) - else: + elif self.fg_mode == FgMode.FG_NORMAL: inputs = copy.copy(self.inputs) input_feats = [input_data[inputs.pop(0)].cast(pa.string()).tolist()] if not self._wildcard_pkey: @@ -153,6 +153,10 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: else: values = self._fg_op.transform(*input_feats) parsed_feat = DenseData(name=self.name, values=values) + else: + raise ValueError( + f"fg_mode: {self.fg_mode} is not supported without fg handler." + ) return parsed_feat def fg_json(self) -> List[Dict[str, Any]]: diff --git a/tzrec/features/match_feature_test.py b/tzrec/features/match_feature_test.py index e2f2081..9414b35 100644 --- a/tzrec/features/match_feature_test.py +++ b/tzrec/features/match_feature_test.py @@ -63,7 +63,7 @@ def test_match_feature_dense(self): ) ) match_feat = match_feature_lib.MatchFeature( - match_feat_cfg, fg_mode=FgMode.NORMAL + match_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(match_feat.output_dim, 1) self.assertEqual(match_feat.is_sparse, False) @@ -112,7 +112,7 @@ def test_match_feature_with_boundary( ) ) match_feat = match_feature_lib.MatchFeature( - match_feat_cfg, fg_mode=FgMode.NORMAL + match_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(match_feat.output_dim, 16) self.assertEqual(match_feat.is_sparse, True) @@ -166,7 +166,7 @@ def test_match_feature_with_num_buckets( ) ) match_feat = match_feature_lib.MatchFeature( - match_feat_cfg, fg_mode=FgMode.NORMAL + match_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(match_feat.output_dim, 16) self.assertEqual(match_feat.is_sparse, True) @@ -220,7 +220,7 @@ def test_match_feature_with_hash_bucket_size( ) ) match_feat = match_feature_lib.MatchFeature( - match_feat_cfg, fg_mode=FgMode.NORMAL + match_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(match_feat.output_dim, 16) self.assertEqual(match_feat.is_sparse, True) diff --git a/tzrec/features/overlap_feature.py b/tzrec/features/overlap_feature.py index b9f7263..2f58ae5 100644 --- a/tzrec/features/overlap_feature.py +++ b/tzrec/features/overlap_feature.py @@ -34,13 +34,13 @@ class OverlapFeature(RawFeature): Args: feature_config (FeatureConfig): a instance of feature config. fg_mode (FgMode): input data fg mode. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE """ def __init__( self, feature_config: FeatureConfig, - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, fg_encoded_multival_sep: Optional[str] = None, ) -> None: super().__init__(feature_config, fg_mode, fg_encoded_multival_sep) @@ -73,7 +73,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: Return: parsed feature data. """ - if self.fg_encoded: + if self.fg_mode == FgMode.FG_NONE: # input feature is already lookuped feat = input_data[self.name] if self.is_sparse: @@ -84,7 +84,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: parsed_feat = _parse_fg_encoded_dense_feature_impl( self.name, feat, **self._fg_encoded_kwargs ) - else: + elif self.fg_mode == FgMode.FG_NORMAL: input_feats = [] for name in self.inputs: x = input_data[name] @@ -97,6 +97,10 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: else: values = self._fg_op.transform(input_feats) parsed_feat = DenseData(name=self.name, values=values) + else: + raise ValueError( + f"fg_mode: {self.fg_mode} is not supported without fg handler." + ) return parsed_feat def fg_json(self) -> List[Dict[str, Any]]: diff --git a/tzrec/features/overlap_feature_test.py b/tzrec/features/overlap_feature_test.py index 8177ec3..e5fbca4 100644 --- a/tzrec/features/overlap_feature_test.py +++ b/tzrec/features/overlap_feature_test.py @@ -113,7 +113,7 @@ def test_overlap_feature_dense( ) ) overlap_feat = overlap_feature_lib.OverlapFeature( - overlap_feat_cfg, fg_mode=FgMode.NORMAL + overlap_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(overlap_feat.output_dim, 1) self.assertEqual(overlap_feat.is_sparse, False) @@ -150,7 +150,7 @@ def test_overlap_feature_with_boundaries( ) ) overlap_feat = overlap_feature_lib.OverlapFeature( - overlap_feat_cfg, fg_mode=FgMode.NORMAL + overlap_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(overlap_feat.output_dim, 16) self.assertEqual(overlap_feat.is_sparse, True) diff --git a/tzrec/features/raw_feature.py b/tzrec/features/raw_feature.py index 9b16440..ce4c2fe 100644 --- a/tzrec/features/raw_feature.py +++ b/tzrec/features/raw_feature.py @@ -33,13 +33,13 @@ class RawFeature(BaseFeature): Args: feature_config (FeatureConfig): a instance of feature config. fg_mode (FgMode): input data fg mode. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE """ def __init__( self, feature_config: FeatureConfig, - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, fg_encoded_multival_sep: Optional[str] = None, ) -> None: super().__init__(feature_config, fg_mode, fg_encoded_multival_sep) @@ -82,7 +82,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: Return: parsed feature data. """ - if self.fg_encoded: + if self.fg_mode == FgMode.FG_NONE: feat = input_data[self.name] if self.is_sparse: parsed_feat = _parse_fg_encoded_sparse_feature_impl( @@ -92,7 +92,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: parsed_feat = _parse_fg_encoded_dense_feature_impl( self.name, feat, **self._fg_encoded_kwargs ) - else: + elif self.fg_mode == FgMode.FG_NORMAL: input_feat = input_data[self.inputs[0]] if pa.types.is_list(input_feat.type): input_feat = input_feat.fill_null([]) @@ -103,6 +103,10 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: else: values = self._fg_op.transform(input_feat) parsed_feat = DenseData(name=self.name, values=values) + else: + raise ValueError( + f"fg_mode: {self.fg_mode} is not supported without fg handler." + ) return parsed_feat def fg_json(self) -> List[Dict[str, Any]]: diff --git a/tzrec/features/raw_feature_test.py b/tzrec/features/raw_feature_test.py index c9e3fd2..5ce21df 100644 --- a/tzrec/features/raw_feature_test.py +++ b/tzrec/features/raw_feature_test.py @@ -172,7 +172,7 @@ def test_raw_feature_dense( value_dim=value_dim, ) ) - raw_feat = raw_feature_lib.RawFeature(raw_feat_cfg, fg_mode=FgMode.NORMAL) + raw_feat = raw_feature_lib.RawFeature(raw_feat_cfg, fg_mode=FgMode.FG_NORMAL) np.testing.assert_allclose( raw_feat.fg_encoded_default_value(), expected_fg_default ) @@ -241,7 +241,7 @@ def test_raw_feature_with_boundaries( value_dim=value_dim, ) ) - raw_feat = raw_feature_lib.RawFeature(raw_feat_cfg, fg_mode=FgMode.NORMAL) + raw_feat = raw_feature_lib.RawFeature(raw_feat_cfg, fg_mode=FgMode.FG_NORMAL) fg_default = raw_feat.fg_encoded_default_value() if expected_fg_default: np.testing.assert_allclose(fg_default, expected_fg_default) diff --git a/tzrec/features/sequence_feature.py b/tzrec/features/sequence_feature.py index fa7c9c6..d59a892 100644 --- a/tzrec/features/sequence_feature.py +++ b/tzrec/features/sequence_feature.py @@ -169,7 +169,7 @@ class SequenceIdFeature(IdFeature): sequence_length (int): max sequence length. sequence_pk (str): sequence primary key name for serving. fg_mode (FgMode): input data fg mode. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE """ def __init__( @@ -179,7 +179,7 @@ def __init__( sequence_delim: Optional[str] = None, sequence_length: Optional[int] = None, sequence_pk: Optional[str] = None, - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, fg_encoded_multival_sep: Optional[str] = None, ) -> None: fc_type = feature_config.WhichOneof("feature") @@ -238,7 +238,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: Return: parsed feature data. """ - if self.fg_encoded: + if self.fg_mode == FgMode.FG_NONE: feat = input_data[self.name] parsed_feat = _parse_fg_encoded_sequence_sparse_feature_impl( self.name, @@ -246,7 +246,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: sequence_delim=self.sequence_delim, **self._fg_encoded_kwargs, ) - else: + elif self.fg_mode == FgMode.FG_NORMAL: input_feat = input_data[self.inputs[0]] if pa.types.is_list(input_feat.type): input_feat = input_feat.fill_null([]) @@ -260,6 +260,10 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: lengths=key_lengths, seq_lengths=seq_lengths, ) + else: + raise ValueError( + f"fg_mode: {self.fg_mode} is not supported without fg handler." + ) return parsed_feat def init_fg(self) -> None: @@ -341,7 +345,7 @@ class SequenceRawFeature(RawFeature): sequence_length (int): max sequence length. sequence_pk (str): sequence primary key name for serving. fg_mode (FgMode): input data fg mode. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE """ def __init__( @@ -351,7 +355,7 @@ def __init__( sequence_delim: Optional[str] = None, sequence_length: Optional[int] = None, sequence_pk: Optional[str] = None, - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, fg_encoded_multival_sep: Optional[str] = None, ) -> None: fc_type = feature_config.WhichOneof("feature") @@ -410,7 +414,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: Return: parsed feature data. """ - if self.fg_encoded: + if self.fg_mode == FgMode.FG_NONE: feat = input_data[self.name] if self.is_sparse: parsed_feat = _parse_fg_encoded_sequence_sparse_feature_impl( @@ -427,7 +431,7 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: value_dim=self.config.value_dim, **self._fg_encoded_kwargs, ) - else: + elif self.fg_mode == FgMode.FG_NONE: input_feat = input_data[self.inputs[0]] if pa.types.is_list(input_feat.type): input_feat = input_feat.fill_null([]) @@ -447,6 +451,10 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: values=values, seq_lengths=lengths, ) + else: + raise ValueError( + f"fg_mode: {self.fg_mode} is not supported without fg handler." + ) return parsed_feat def init_fg(self) -> None: diff --git a/tzrec/features/sequence_feature_test.py b/tzrec/features/sequence_feature_test.py index 3d0dc13..bd36355 100644 --- a/tzrec/features/sequence_feature_test.py +++ b/tzrec/features/sequence_feature_test.py @@ -151,7 +151,7 @@ def test_sequence_id_feature_with_hash_bucket_size( sequence_name="click_50_seq", sequence_delim=";", sequence_length=50, - fg_mode=FgMode.NORMAL, + fg_mode=FgMode.FG_NORMAL, ) self.assertEqual(seq_feat.output_dim, 16) self.assertEqual(seq_feat.is_sparse, True) @@ -196,7 +196,7 @@ def test_simple_sequence_id_feature_with_hash_bucket_size( ) seq_feat = sequence_feature_lib.SequenceIdFeature( seq_feat_cfg, - fg_mode=FgMode.NORMAL, + fg_mode=FgMode.FG_NORMAL, ) self.assertEqual(seq_feat.output_dim, 16) self.assertEqual(seq_feat.is_sparse, True) @@ -273,7 +273,7 @@ def test_sequence_id_feature_with_num_buckets( sequence_name="click_50_seq", sequence_delim=";", sequence_length=50, - fg_mode=FgMode.NORMAL, + fg_mode=FgMode.FG_NORMAL, ) self.assertEqual(seq_feat.output_dim, 16) self.assertEqual(seq_feat.is_sparse, True) @@ -310,7 +310,7 @@ def test_sequence_id_feature_with_vocab_list(self): sequence_name="click_50_seq", sequence_delim="|", sequence_length=50, - fg_mode=FgMode.NORMAL, + fg_mode=FgMode.FG_NORMAL, ) input_data = {"click_50_seq__id_str": pa.array(["c||a|b|b|", "", "a|b||c"])} parsed_feat = seq_feat.parse(input_data) @@ -561,7 +561,7 @@ def test_sequence_sequence_raw_feature_dense( sequence_name="click_50_seq", sequence_delim=";", sequence_length=50, - fg_mode=FgMode.NORMAL, + fg_mode=FgMode.FG_NORMAL, ) self.assertEqual(seq_feat.output_dim, value_dim) self.assertEqual(seq_feat.is_sparse, False) @@ -608,7 +608,7 @@ def test_simple_sequence_sequence_raw_feature_dense( ) seq_feat = sequence_feature_lib.SequenceRawFeature( seq_feat_cfg, - fg_mode=FgMode.NORMAL, + fg_mode=FgMode.FG_NORMAL, ) self.assertEqual(seq_feat.output_dim, value_dim) self.assertEqual(seq_feat.is_sparse, False) @@ -667,7 +667,7 @@ def test_sequence_sequence_raw_feature_with_boundaries( sequence_name="click_50_seq", sequence_delim=";", sequence_length=50, - fg_mode=FgMode.NORMAL, + fg_mode=FgMode.FG_NORMAL, ) self.assertEqual(seq_feat.output_dim, 16) self.assertEqual(seq_feat.is_sparse, True) diff --git a/tzrec/features/tokenize_feature.py b/tzrec/features/tokenize_feature.py index fa14e96..e6db5e7 100644 --- a/tzrec/features/tokenize_feature.py +++ b/tzrec/features/tokenize_feature.py @@ -40,13 +40,13 @@ class TokenizeFeature(IdFeature): Args: feature_config (FeatureConfig): a instance of feature config. fg_mode (FgMode): input data fg mode. - fg_encoded_multival_sep (str, optional): multival_sep when fg_encoded=true + fg_encoded_multival_sep (str, optional): multival_sep when fg_mode=FG_NONE """ def __init__( self, feature_config: FeatureConfig, - fg_mode: FgMode = FgMode.ENCODED, + fg_mode: FgMode = FgMode.FG_NONE, fg_encoded_multival_sep: Optional[str] = None, ) -> None: super().__init__(feature_config, fg_mode, fg_encoded_multival_sep) @@ -74,13 +74,13 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: Return: parsed feature data. """ - if self.fg_encoded: + if self.fg_mode == FgMode.FG_NONE: # input feature is already bucktized feat = input_data[self.name] parsed_feat = _parse_fg_encoded_sparse_feature_impl( self.name, feat, **self._fg_encoded_kwargs ) - else: + elif self.fg_mode == FgMode.FG_NORMAL: input_feat = input_data[self.inputs[0]] if pa.types.is_list(input_feat.type): input_feat = input_feat.fill_null([]) @@ -93,6 +93,10 @@ def _parse(self, input_data: Dict[str, pa.Array]) -> ParsedData: else: values, lengths = self._fg_op.to_bucketized_jagged_tensor([input_feat]) parsed_feat = SparseData(name=self.name, values=values, lengths=lengths) + else: + raise ValueError( + f"fg_mode: {self.fg_mode} is not supported without fg handler." + ) return parsed_feat def init_fg(self) -> None: diff --git a/tzrec/features/tokenize_feature_test.py b/tzrec/features/tokenize_feature_test.py index 4d5755f..3efb3af 100644 --- a/tzrec/features/tokenize_feature_test.py +++ b/tzrec/features/tokenize_feature_test.py @@ -104,7 +104,7 @@ def test_tokenize_feature( ) token_feat_cfg.tokenize_feature.text_normalizer.CopyFrom(text_norm) token_feat = tokenize_feature_lib.TokenizeFeature( - token_feat_cfg, fg_mode=FgMode.NORMAL + token_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(token_feat.inputs, ["token_input"]) @@ -165,7 +165,7 @@ def test_tokenize_feature_sentencepiece( ) ) token_feat = tokenize_feature_lib.TokenizeFeature( - token_feat_cfg, fg_mode=FgMode.NORMAL + token_feat_cfg, fg_mode=FgMode.FG_NORMAL ) self.assertEqual(token_feat.inputs, ["token_input"]) diff --git a/tzrec/main.py b/tzrec/main.py index b176d4a..7a9ab4d 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -56,7 +56,6 @@ from tzrec.datasets.utils import Batch, RecordBatchTensor from tzrec.features.feature import ( BaseFeature, - FgMode, create_feature_configs, create_features, create_fg_json, @@ -111,15 +110,9 @@ def _create_features( getattr(data_config, data_config.WhichOneof("sampler")).attr_fields ) - if data_config.fg_encoded: - fg_mode = FgMode.ENCODED - elif data_config.fg_threads > 0: - fg_mode = FgMode.DAG - else: - fg_mode = FgMode.NORMAL features = create_features( feature_configs, - fg_mode=fg_mode, + fg_mode=data_config.fg_mode, neg_fields=neg_fields, fg_encoded_multival_sep=data_config.fg_encoded_multival_sep, force_base_data_group=data_config.force_base_data_group, diff --git a/tzrec/predict.py b/tzrec/predict.py index de4ec9e..216e503 100644 --- a/tzrec/predict.py +++ b/tzrec/predict.py @@ -85,7 +85,7 @@ "--edit_config_json", type=str, default=None, - help='edit pipeline config str, example: {"data_config.fg_encoded":true}', + help='edit pipeline config str, example: {"data_config.fg_mode":"FG_DAG"}', ) args, extra_args = parser.parse_known_args() diff --git a/tzrec/protos/data.proto b/tzrec/protos/data.proto index 4e62d9b..a92b06e 100644 --- a/tzrec/protos/data.proto +++ b/tzrec/protos/data.proto @@ -18,6 +18,21 @@ enum FieldType { DOUBLE = 5; } +enum FgMode { + // input data is feature generate encoded, + // we do not do fg + FG_NONE = 1; + // input data is raw feature, + // we use python to run feature generate + FG_NORMAL = 2; + // input data is raw feature, + // we use fg_handler to run feature generate + FG_DAG = 3; + // input data is after feature generate but before do bucketize, + // we do bucketize only + FG_BUCKETIZE = 4; +} + message Field { required string input_name = 1; // only need specify it when use CsvDataset and @@ -32,11 +47,15 @@ message DataConfig { // dataset type. required DatasetType dataset_type = 2 [default = OdpsDataset]; + // [deprecated] please use fg_mode. // input data is feature generate encoded or not. // if fg_encoded = true, you should do fg offline first, // and set fg_encoded_multival_sep for split multi-val feature optional bool fg_encoded = 3 [default = true]; + // fg run mode. + optional FgMode fg_mode = 20 [default = FG_NONE]; + // separator for multi-val feature in fg encoded input data optional string fg_encoded_multival_sep = 4 [default = '\x03']; diff --git a/tzrec/protos/feature.proto b/tzrec/protos/feature.proto index ce580dc..d2ad33f 100644 --- a/tzrec/protos/feature.proto +++ b/tzrec/protos/feature.proto @@ -72,7 +72,7 @@ message IdFeature { // zero collision hash optional ZeroCollisionHash zch = 16; - // default value when fg_encoded = true, + // default value when fg_mode = FG_NONE, // when use pai-fg, you do not need to set the param. // when use own fg and data contain null value, you can set the param for fill null optional string fg_encoded_default_value = 30; @@ -110,7 +110,7 @@ message RawFeature { // mask value in training progress optional bool use_mask = 14; - // default value when fg_encoded = true, + // default value when fg_mode = FG_NONE, // when use pai-fg, you do not need to set the param. // when use own fg and data contain null value, you can set the param for fill null optional string fg_encoded_default_value = 30; @@ -148,7 +148,7 @@ message ComboFeature { // zero collision hash optional ZeroCollisionHash zch = 15; - // default value when fg_encoded = true, + // default value when fg_mode = FG_NONE, // when use pai-fg, you do not need to set the param. // when use own fg and data contain null value, you can set the param for fill null optional string fg_encoded_default_value = 30; @@ -208,7 +208,7 @@ message LookupFeature { // zero collision hash optional ZeroCollisionHash zch = 26; - // default value when fg_encoded = true, + // default value when fg_mode = FG_NONE, // when use pai-fg, you do not need to set the param. // when use own fg and data contain null value, you can set the param for fill null optional string fg_encoded_default_value = 30; @@ -270,7 +270,7 @@ message MatchFeature { // zero collision hash optional ZeroCollisionHash zch = 26; - // default value when fg_encoded = true, + // default value when fg_mode = FG_NONE, // when use pai-fg, you do not need to set the param. // when use own fg and data contain null value, you can set the param for fill null optional string fg_encoded_default_value = 30; @@ -298,7 +298,7 @@ message ExprFeature { // mask value in training progress optional bool use_mask = 13; - // default value when fg_encoded = true, + // default value when fg_mode = FG_NONE, // when use pai-fg, you do not need to set the param. // when use own fg and data contain null value, you can set the param for fill null optional string fg_encoded_default_value = 30; @@ -336,7 +336,7 @@ message OverlapFeature { // mask value in training progress optional bool use_mask = 14; - // default value when fg_encoded = true, + // default value when fg_mode = FG_NONE, // when use pai-fg, you do not need to set the param. // when use own fg and data contain null value, you can set the param for fill null optional string fg_encoded_default_value = 30; @@ -395,7 +395,7 @@ message TokenizeFeature { // mask value in training progress optional bool use_mask = 15; - // default value when fg_encoded = true, + // default value when fg_mode = FG_NONE, // when use pai-fg, you do not need to set the param. // when use own fg and data contain null value, you can set the param for fill null optional string fg_encoded_default_value = 30; @@ -439,7 +439,7 @@ message SequenceIdFeature { // zero collision hash optional ZeroCollisionHash zch = 21; - // default value when fg_encoded = true, + // default value when fg_mode = FG_NONE, // when use pai-fg, you do not need to set the param. // when use own fg and data contain null value, you can set the param for fill null optional string fg_encoded_default_value = 30; @@ -480,7 +480,7 @@ message SequenceRawFeature { // mask value in training progress optional bool use_mask = 14; - // default value when fg_encoded = true, + // default value when fg_mode = FG_NONE, // when use pai-fg, you do not need to set the param. // when use own fg and data contain null value, you can set the param for fill null optional string fg_encoded_default_value = 30; diff --git a/tzrec/tests/configs/dbmtl_has_sequence_mock.config b/tzrec/tests/configs/dbmtl_has_sequence_mock.config index 184c222..2e0d091 100644 --- a/tzrec/tests/configs/dbmtl_has_sequence_mock.config +++ b/tzrec/tests/configs/dbmtl_has_sequence_mock.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: ParquetDataset - fg_encoded: true label_fields: "clk" label_fields: "buy" num_workers: 8 diff --git a/tzrec/tests/configs/dbmtl_has_sequence_variational_dropout_mock.config b/tzrec/tests/configs/dbmtl_has_sequence_variational_dropout_mock.config index 83c140a..03b15d2 100644 --- a/tzrec/tests/configs/dbmtl_has_sequence_variational_dropout_mock.config +++ b/tzrec/tests/configs/dbmtl_has_sequence_variational_dropout_mock.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: ParquetDataset - fg_encoded: true label_fields: "clk" label_fields: "buy" num_workers: 8 diff --git a/tzrec/tests/configs/dssm_fg_mock.config b/tzrec/tests/configs/dssm_fg_mock.config index d71d938..7594c5e 100644 --- a/tzrec/tests/configs/dssm_fg_mock.config +++ b/tzrec/tests/configs/dssm_fg_mock.config @@ -23,7 +23,7 @@ eval_config { data_config { batch_size: 8192 dataset_type: ParquetDataset - fg_encoded: false + fg_mode: FG_DAG label_fields: "clk" num_workers: 8 negative_sampler { diff --git a/tzrec/tests/configs/dssm_mock.config b/tzrec/tests/configs/dssm_mock.config index 31486a9..074defe 100644 --- a/tzrec/tests/configs/dssm_mock.config +++ b/tzrec/tests/configs/dssm_mock.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: ParquetDataset - fg_encoded: true label_fields: "clk" num_workers: 8 negative_sampler { diff --git a/tzrec/tests/configs/dssm_v2_fg_mock.config b/tzrec/tests/configs/dssm_v2_fg_mock.config index 92ace7d..44371a0 100644 --- a/tzrec/tests/configs/dssm_v2_fg_mock.config +++ b/tzrec/tests/configs/dssm_v2_fg_mock.config @@ -23,7 +23,7 @@ eval_config { data_config { batch_size: 8192 dataset_type: ParquetDataset - fg_encoded: false + fg_mode: FG_DAG label_fields: "clk" num_workers: 8 force_base_data_group: true diff --git a/tzrec/tests/configs/dssm_variational_dropout_mock.config b/tzrec/tests/configs/dssm_variational_dropout_mock.config index 35e69af..97affa9 100644 --- a/tzrec/tests/configs/dssm_variational_dropout_mock.config +++ b/tzrec/tests/configs/dssm_variational_dropout_mock.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: ParquetDataset - fg_encoded: true label_fields: "clk" num_workers: 8 negative_sampler { diff --git a/tzrec/tests/configs/multi_tower_din_fg_mock.config b/tzrec/tests/configs/multi_tower_din_fg_mock.config index 130def9..d0cd795 100644 --- a/tzrec/tests/configs/multi_tower_din_fg_mock.config +++ b/tzrec/tests/configs/multi_tower_din_fg_mock.config @@ -23,7 +23,7 @@ eval_config { data_config { batch_size: 8192 dataset_type: ParquetDataset - fg_encoded: false + fg_mode: FG_DAG label_fields: "clk" num_workers: 8 } diff --git a/tzrec/tests/configs/multi_tower_din_mock.config b/tzrec/tests/configs/multi_tower_din_mock.config index bc875a2..7e3a844 100644 --- a/tzrec/tests/configs/multi_tower_din_mock.config +++ b/tzrec/tests/configs/multi_tower_din_mock.config @@ -23,7 +23,6 @@ eval_config { data_config { batch_size: 8192 dataset_type: ParquetDataset - fg_encoded: true label_fields: "clk" num_workers: 8 } diff --git a/tzrec/tests/configs/multi_tower_din_trt_fg_mock.config b/tzrec/tests/configs/multi_tower_din_trt_fg_mock.config index d8d0c03..7e7bbda 100644 --- a/tzrec/tests/configs/multi_tower_din_trt_fg_mock.config +++ b/tzrec/tests/configs/multi_tower_din_trt_fg_mock.config @@ -23,7 +23,7 @@ eval_config { data_config { batch_size: 8192 dataset_type: ParquetDataset - fg_encoded: false + fg_mode: FG_DAG label_fields: "clk" num_workers: 8 } diff --git a/tzrec/tests/configs/multi_tower_din_zch_fg_mock.config b/tzrec/tests/configs/multi_tower_din_zch_fg_mock.config index 0789a39..75af22d 100644 --- a/tzrec/tests/configs/multi_tower_din_zch_fg_mock.config +++ b/tzrec/tests/configs/multi_tower_din_zch_fg_mock.config @@ -23,7 +23,7 @@ eval_config { data_config { batch_size: 8192 dataset_type: ParquetDataset - fg_encoded: false + fg_mode: FG_DAG label_fields: "clk" num_workers: 8 } diff --git a/tzrec/tests/configs/multi_tower_din_zch_trt_fg_mock.config b/tzrec/tests/configs/multi_tower_din_zch_trt_fg_mock.config index b1578d7..028677d 100644 --- a/tzrec/tests/configs/multi_tower_din_zch_trt_fg_mock.config +++ b/tzrec/tests/configs/multi_tower_din_zch_trt_fg_mock.config @@ -23,7 +23,7 @@ eval_config { data_config { batch_size: 8192 dataset_type: ParquetDataset - fg_encoded: false + fg_mode: FG_DAG label_fields: "clk" num_workers: 8 } diff --git a/tzrec/tests/configs/tdm_fg_mock.config b/tzrec/tests/configs/tdm_fg_mock.config index 8ba6d56..d2f6935 100644 --- a/tzrec/tests/configs/tdm_fg_mock.config +++ b/tzrec/tests/configs/tdm_fg_mock.config @@ -23,7 +23,7 @@ eval_config { data_config { batch_size: 32 dataset_type: ParquetDataset - fg_encoded: false + fg_mode: FG_DAG label_fields: "clk" num_workers: 8 tdm_sampler { diff --git a/tzrec/tests/match_integration_test.py b/tzrec/tests/match_integration_test.py index 442c51a..92fcc41 100644 --- a/tzrec/tests/match_integration_test.py +++ b/tzrec/tests/match_integration_test.py @@ -30,7 +30,7 @@ def tearDown(self): if os.path.exists(self.test_dir): shutil.rmtree(self.test_dir) - def test_dssm_fg_encoded_train_eval_export(self): + def test_dssm_nofg_train_eval_export(self): self.success = utils.test_train_eval( "tzrec/tests/configs/dssm_mock.config", self.test_dir, item_id="item_id" ) @@ -62,7 +62,7 @@ def test_dssm_fg_encoded_train_eval_export(self): os.path.exists(os.path.join(self.test_dir, "export/item/scripted_model.pt")) ) - def test_dssm_fg_encoded_variational_dropout(self): + def test_dssm_nofg_variational_dropout(self): self.success = utils.test_train_eval( "tzrec/tests/configs/dssm_variational_dropout_mock.config", self.test_dir, diff --git a/tzrec/tests/rank_integration_test.py b/tzrec/tests/rank_integration_test.py index 6084fe4..e763bfa 100644 --- a/tzrec/tests/rank_integration_test.py +++ b/tzrec/tests/rank_integration_test.py @@ -41,9 +41,7 @@ def tearDown(self): os.environ.pop("INPUT_TILE", None) os.environ.pop("ENABLE_TRT", None) - def _test_rank_fg_encoded( - self, pipeline_config_path, reserved_columns, output_columns - ): + def _test_rank_nofg(self, pipeline_config_path, reserved_columns, output_columns): self.success = utils.test_train_eval(pipeline_config_path, self.test_dir) if self.success: self.success = utils.test_eval( @@ -71,14 +69,14 @@ def _test_rank_fg_encoded( ) def test_multi_tower_din_fg_encoded_train_eval_export(self): - self._test_rank_fg_encoded( + self._test_rank_nofg( "tzrec/tests/configs/multi_tower_din_mock.config", reserved_columns="clk", output_columns="probs", ) def test_dbmtl_has_sequence_fg_encoded_train_eval_export(self): - self._test_rank_fg_encoded( + self._test_rank_nofg( "tzrec/tests/configs/dbmtl_has_sequence_mock.config", reserved_columns="clk,buy", output_columns="probs_ctr,probs_cvr", diff --git a/tzrec/tests/utils.py b/tzrec/tests/utils.py index 2c1623e..791f35c 100644 --- a/tzrec/tests/utils.py +++ b/tzrec/tests/utils.py @@ -696,12 +696,12 @@ def load_config_for_test( features = create_features( list(pipeline_config.feature_configs), - fg_mode=FgMode.ENCODED if data_config.fg_encoded else FgMode.DAG, + fg_mode=data_config.fg_mode, ) data_config.num_workers = 2 num_parts = data_config.num_workers * 2 - if data_config.fg_encoded: + if data_config.fg_mode == FgMode.FG_NONE: inputs = build_mock_input_fg_encoded(features, user_id, item_id) item_inputs = inputs pipeline_config.train_input_path, _ = create_mock_data( @@ -1037,9 +1037,7 @@ def create_predict_data( ) features = create_features( pipeline_config.feature_configs, - fg_mode=FgMode.ENCODED - if pipeline_config.data_config.fg_encoded - else FgMode.DAG, + fg_mode=pipeline_config.data_config.fg_mode, ) user_inputs = [] for feature in features: diff --git a/tzrec/tools/convert_easyrec_config_to_tzrec_config.py b/tzrec/tools/convert_easyrec_config_to_tzrec_config.py index 6322eff..f92dba6 100644 --- a/tzrec/tools/convert_easyrec_config_to_tzrec_config.py +++ b/tzrec/tools/convert_easyrec_config_to_tzrec_config.py @@ -178,7 +178,6 @@ def _create_data_config(self, pipeline_config): self.easyrec_config.data_config.batch_size ) pipeline_config.data_config.dataset_type = DatasetType.OdpsDataset - pipeline_config.data_config.fg_encoded = True pipeline_config.data_config.label_fields.extend(label_fields) pipeline_config.data_config.num_workers = 8 pipeline_config.data_config.odps_data_quota_name = "" diff --git a/tzrec/tools/convert_easyrec_config_to_tzrec_config_test.py b/tzrec/tools/convert_easyrec_config_to_tzrec_config_test.py index 4a1d680..d590dd8 100644 --- a/tzrec/tools/convert_easyrec_config_to_tzrec_config_test.py +++ b/tzrec/tools/convert_easyrec_config_to_tzrec_config_test.py @@ -455,7 +455,6 @@ DATA_CONFIG = """data_config { batch_size: 4096 dataset_type: OdpsDataset - fg_encoded: true label_fields: "is_click_cover" label_fields: "is_click_video" num_workers: 8 diff --git a/tzrec/utils/config_util.py b/tzrec/utils/config_util.py index eed332d..96c2dec 100644 --- a/tzrec/utils/config_util.py +++ b/tzrec/utils/config_util.py @@ -18,6 +18,8 @@ from google.protobuf.message import Message from tzrec.protos import pipeline_pb2 +from tzrec.protos.data_pb2 import FgMode +from tzrec.utils.logging_util import logger def load_pipeline_config( @@ -41,6 +43,8 @@ def load_pipeline_config( ) else: text_format.Merge(f.read(), config, allow_unknown_field=allow_unknown_field) + # compatible for fg_encoded + config.data_config.fg_mode = _get_compatible_fg_mode(config.data_config) return config @@ -71,6 +75,24 @@ def which_msg(config: Message, oneof_group: str) -> str: return getattr(config, config.WhichOneof(oneof_group)).__class__.__name__ +def _get_compatible_fg_mode(data_config: Message) -> FgMode: + """Compat for fg_encoded.""" + if data_config.HasField("fg_encoded"): + logger.warning( + "data_config.fg_encoded will be deprecated, " + "please use data_config.fg_mode." + ) + if data_config.fg_encoded: + fg_mode = FgMode.FG_NONE + elif data_config.fg_threads > 0: + fg_mode = FgMode.FG_DAG + else: + fg_mode = FgMode.FG_NORMAL + else: + fg_mode = data_config.fg_mode + return fg_mode + + # pyre-ignore [24] def _get_basic_types() -> List[Type]: dtypes = [