From 5cccc4f9c7cedfb7b48120de00966fc28c65467b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=91=BE=E6=B4=81?= Date: Fri, 29 Nov 2024 09:55:59 +0800 Subject: [PATCH] [bugfix] input_tile=3: make dataparser to get user feats before create model --- tzrec/main.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/tzrec/main.py b/tzrec/main.py index d28864d..2619621 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -819,6 +819,12 @@ def export( # Build feature features = _create_features(list(pipeline_config.feature_configs), data_config) + # make dataparser to get user feats before create model + data_config.num_workers = 1 + dataloader = _get_dataloader( + data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT + ) + # Build model model = _create_model( pipeline_config.model_config, @@ -881,11 +887,6 @@ def export( list(data_config.label_fields), ) - data_config.num_workers = 1 - dataloader = _get_dataloader( - data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT - ) - if isinstance(device_model, MatchModel): for name, module in device_model.named_children(): if isinstance(module, MatchTower) or isinstance(module, MatchTowerWoEG):