Skip to content

Commit

Permalink
[bugfix] input_tile=3: make dataparser to get user feats before creat…
Browse files Browse the repository at this point in the history
…e model
  • Loading branch information
yjjinjie committed Nov 29, 2024
1 parent eddfefd commit 5cccc4f
Showing 1 changed file with 6 additions and 5 deletions.
11 changes: 6 additions & 5 deletions tzrec/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -819,6 +819,12 @@ def export(
# Build feature
features = _create_features(list(pipeline_config.feature_configs), data_config)

# make dataparser to get user feats before create model
data_config.num_workers = 1
dataloader = _get_dataloader(
data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT
)

# Build model
model = _create_model(
pipeline_config.model_config,
Expand Down Expand Up @@ -881,11 +887,6 @@ def export(
list(data_config.label_fields),
)

data_config.num_workers = 1
dataloader = _get_dataloader(
data_config, features, pipeline_config.train_input_path, mode=Mode.PREDICT
)

if isinstance(device_model, MatchModel):
for name, module in device_model.named_children():
if isinstance(module, MatchTower) or isinstance(module, MatchTowerWoEG):
Expand Down

0 comments on commit 5cccc4f

Please sign in to comment.