diff --git a/tzrec/main.py b/tzrec/main.py index d28864d..2619621 100644 --- a/tzrec/main.py +++ b/tzrec/main.py @@ -819,6 +819,12 @@ def export( # Build feature features = _create_features(list(pipeline_config.feature_configs), data_config) + # make dataparser to get user feats before create model + data_config.num_workers = 1 + dataloader = _get_dataloader( + data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT + ) + # Build model model = _create_model( pipeline_config.model_config, @@ -881,11 +887,6 @@ def export( list(data_config.label_fields), ) - data_config.num_workers = 1 - dataloader = _get_dataloader( - data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT - ) - if isinstance(device_model, MatchModel): for name, module in device_model.named_children(): if isinstance(module, MatchTower) or isinstance(module, MatchTowerWoEG):